[Mlir-commits] [mlir] Add missing FillOp to winograd lowering (PR #108181)
Thomas Preud'homme
llvmlistbot at llvm.org
Fri Sep 13 07:47:06 PDT 2024
https://github.com/RoboTux updated https://github.com/llvm/llvm-project/pull/108181
>From 7f44e5429b52736a25e4996290772a2cc5e2d2ce Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Tue, 10 Sep 2024 17:15:41 +0100
Subject: [PATCH 1/3] [Linalg] Add missing FillOp to winograd lowering
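Winograd lowering involves a number of matmul and batch_matmul ops which
are currently passed a tensor.empty result as their out parameter. Since
these ops accumulate into their out operand and the contents of a
tensor.empty are unspecified, the results are undefined behaviour. This
commit adds the necessary linalg.fill to zero-initialise the out operands.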
---
.../Linalg/Transforms/WinogradConv2D.cpp | 57 ++++--
.../transform-tile-and-winograd-rewrite.mlir | 179 +++++++++---------
.../Linalg/winograd-conv2d-rewrite.mlir | 96 +++++-----
mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 147 +++++++-------
4 files changed, 266 insertions(+), 213 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index b65b18699a15aa..80edf4a32c6df8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -390,6 +390,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
TransformMapKeyTy key = {m, r};
int64_t retRows = 1;
Value matmulRetValue = extractFilter;
+ Value zero = builder.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix G.
auto it = GMatrices.find(key);
@@ -399,8 +401,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
retRows = GMatrix.rows;
auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
- auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
- elementType);
+ auto empty =
+ builder
+ .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+ .getResult();
+ auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
// Multiply G x g.
@@ -418,8 +423,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
auto matmulType =
RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
- auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
- elementType);
+ auto empty =
+ builder
+ .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+ .getResult();
+ auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
// Multiply u = (G x g) x GT.
@@ -523,6 +531,8 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
int64_t retRows = 1;
int64_t retCols = 1;
Value matmulRetValue = extractInput;
+ Value zero = builder.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix BT.
auto it = BTMatrices.find(key);
@@ -532,8 +542,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
retRows = BTMatrix.rows;
auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
- auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
- elementType);
+ auto empty =
+ builder
+ .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+ .getResult();
+ auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
Value BT =
create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
@@ -552,8 +565,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
retCols = BMatrix.cols;
auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
- auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
- elementType);
+ auto empty =
+ builder
+ .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+ .getResult();
+ auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
Value B =
create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
// Multiply v = (BT x d) x B.
@@ -636,8 +652,13 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
{inputShape[0] * inputShape[1],
inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
outputElementType);
- Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
- outputElementType);
+ Value empty = rewriter
+ .create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ outputElementType)
+ .getResult();
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(outputElementType));
+ Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
loc, matmulType, ValueRange({collapseInput, collapseFilter}),
@@ -725,6 +746,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
int64_t leftScalarFactor = 1;
int64_t rightScalarFactor = 1;
Value matmulRetValue = extractValue;
+ Value zero = builder.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix AT.
auto it = ATMatrices.find(key);
@@ -735,8 +758,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
leftScalarFactor = ATMatrix.scalarFactor;
retRows = ATMatrix.rows;
auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
- auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
- elementType);
+ auto empty =
+ builder
+ .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+ .getResult();
+ auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
// Multiply AT x m.
@@ -756,8 +782,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
auto matmulType =
RankedTensorType::get({retRows, AMatrix.cols}, elementType);
retCols = AMatrix.cols;
- auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
- elementType);
+ auto empty =
+ builder
+ .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+ .getResult();
+ auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
// Multiply y = (AT x m) x A.
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 6bb3fb1423edc6..21dcea968615f6 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -44,9 +44,9 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
-// CHECK: %[[S11:.*]] = linalg.matmul
-// CHECK: %[[S13:.*]] = linalg.matmul
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = linalg.matmul
+// CHECK: %[[S15:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S9]]
// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
@@ -56,20 +56,20 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_8]])
// CHECK: %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK: %[[S15:.*]] = linalg.matmul
-// CHECK: %[[S17:.*]] = linalg.matmul
-// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S17]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK: scf.yield %[[INSERTED_SLICE_9]]
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK: %[[S16:.*]] = linalg.matmul
+// CHECK: %[[S19:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_10]]
// CHECK: scf.yield %[[S13]]
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S9]]
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
// CHECK: %[[S6:.*]] = linalg.batch_matmul
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
@@ -78,20 +78,20 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_8]])
// CHECK: %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S17:.*]] = linalg.matmul
-// CHECK: %[[S19:.*]] = linalg.matmul
-// CHECK: %[[S20:.*]] = tensor.empty()
-// CHECK: %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S18:.*]] = linalg.matmul
+// CHECK: %[[S21:.*]] = linalg.matmul
+// CHECK: %[[S22:.*]] = tensor.empty()
+// CHECK: %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S22]] : tensor<4x4xf32>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<4x4xf32>
-// CHECK: %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
-// CHECK: scf.yield %[[INSERTED_SLICE_9]]
+// CHECK: %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_10]]
// CHECK: scf.yield %[[S15]]
// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
@@ -114,14 +114,15 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
%collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
%collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
%4 = tensor.empty() : tensor<36x18x2xf32>
- %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
- %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+ %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+ %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%5 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+ %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
%padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
tensor.yield %cst : f32
} : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
- %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
- %extracted_slice = tensor.extract_slice %6[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+ %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+ %extracted_slice = tensor.extract_slice %7[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
return %extracted_slice : tensor<2x9x9x2xf32>
}
@@ -153,71 +154,72 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[S0:.*]] = tensor.empty()
-// CHECK: %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
-// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
-// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
-// CHECK: %[[S11:.*]] = linalg.matmul
+// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK: %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
// CHECK: %[[S13:.*]] = linalg.matmul
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S16:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
-// CHECK: scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK: scf.yield %[[S10]] : tensor<6x6x5x2xf32>
// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK: %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
-// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
-// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK: %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
-// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK: %[[S15:.*]] = linalg.matmul
+// CHECK: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK: %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S11]], %[[S12]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S13:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK: %[[S14:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
// CHECK: %[[S17:.*]] = linalg.matmul
-// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S17]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S20:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S20]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
-// CHECK: scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[S14]] : tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
-// CHECK: scf.yield %[[S9]]
+// CHECK: scf.yield %[[S10]]
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
// CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
-// CHECK: %[[S6:.*]] = linalg.batch_matmul
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
+// CHECK: %[[S7:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
-// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
-// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK: %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
-// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[S8:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK: %[[S9:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S8]])
+// CHECK: %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S11]], %[[S12]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[S13:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK: %[[S16:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK: %[[S19:.*]] = linalg.matmul
-// CHECK: %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK: %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK: %[[S22:.*]] = linalg.matmul
+// CHECK: %[[S23:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK: %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S23]] : tensor<4x4xf32>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<4x4xf32>
-// CHECK: %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S22]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: %[[S25:.*]] = linalg.mul ins(%[[S24]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S25]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE_12]]
-// CHECK: scf.yield %[[S15]] : tensor<2x4x4x2xf32>
-// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: scf.yield %[[S16]] : tensor<2x4x4x2xf32>
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S15:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, %[[S14]], %[[S15]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
-// CHECK: scf.yield %[[S9]]
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
+// CHECK: scf.yield %[[S10]]
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S9]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
// CHECK: return %[[EXTRACTED_SLICE]]
// -----
func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<6x1x5x2xf32>
%1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
%2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
@@ -225,10 +227,11 @@ func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
%collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
%collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
%4 = tensor.empty() : tensor<6x2x2xf32>
- %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
- %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
- %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
- return %6 : tensor<2x4x1x2xf32>
+ %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+ %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+ %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+ %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+ return %7 : tensor<2x4x1x2xf32>
}
module attributes {transform.with_named_sequence} {
@@ -260,33 +263,33 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
-// CHECK: %[[S9:.*]] = linalg.matmul
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S10:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S7]]
// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
// CHECK: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
-// CHECK: %[[S9:.*]] = linalg.matmul
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S10:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S7]]
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK: %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
-// CHECK: %[[S5:.*]] = linalg.batch_matmul
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
-// CHECK: %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
-// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[COLLAPSED_4:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
+// CHECK: %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S9:.*]] = linalg.matmul
-// CHECK: %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
-// CHECK: %[[S11:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S10]] : tensor<4x1xf32>) {
+// CHECK: %[[S11:.*]] = linalg.matmul
+// CHECK: %[[S12:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK: %[[S13:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S12]] : tensor<4x1xf32>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<4x1xf32>
-// CHECK: %[[S12:.*]] = linalg.mul ins(%[[S11]], %[[S9]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S14:.*]] = linalg.mul ins(%[[S13]], %[[S11]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S12]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
-// CHECK: scf.yield %[[S7]]
-// CHECK: return %[[S6]]
+// CHECK: scf.yield %[[S8]]
+// CHECK: return %[[S7]]
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 095a6636b68dc6..2ffd9fd9c0db21 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -13,14 +13,15 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
%collapsed = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
%collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
%6 = tensor.empty() : tensor<36x18x2xf32>
- %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
- %expanded = tensor.expand_shape %7 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+ %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+ %8 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%7 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+ %expanded = tensor.expand_shape %8 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
%padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
tensor.yield %cst : f32
} : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
- %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
- %extracted_slice = tensor.extract_slice %8[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+ %9 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+ %extracted_slice = tensor.extract_slice %9[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
return %extracted_slice : tensor<2x9x9x2xf32>
}
@@ -44,16 +45,18 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-DAG: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], %[[C0]], %[[C0]], %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32>
-// CHECK-NEXT: %[[S8:.*]] = tensor.empty() : tensor<6x3xf32>
-// CHECK-NEXT: %[[S9:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S8]] : tensor<6x3xf32>) -> tensor<6x3xf32>
-// CHECK-NEXT: %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT: %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[ARG6]][%[[C0]], %[[C0]], %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT: %[[S10:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S9]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT: %[[S11:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT: %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT: %[[S13:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT: %[[S14:.*]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[C0]], %[[C0]], %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x5x2xf32>
// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S7]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT: scf.yield %[[S8]] : tensor<6x6x5x2xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
@@ -61,60 +64,65 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-NEXT: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK-NEXT: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S10]], %[[S11]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32>
-// CHECK-NEXT: %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT: %[[S13:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT: %[[S14:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT: %[[S15:.*]] = linalg.matmul ins(%[[S13]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG10]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT: %[[S10:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT: %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S11]], %[[S12]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32>
+// CHECK-NEXT: %[[S13:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT: %[[S14:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT: %[[S15:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT: %[[S16:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT: %[[S17:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S16]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT: %[[S18:.*]] = linalg.matmul ins(%[[S15]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S18]] into %[[ARG10]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: scf.yield %[[S10]] : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S7]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_7]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_7]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
// CHECK-NEXT: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
// CHECK-NEXT: tensor.yield %[[CST_6]] : f32
// CHECK-NEXT: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT: %[[S10:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6xf32>
-// CHECK-NEXT: %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
-// CHECK-NEXT: %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
-// CHECK-NEXT: %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT: %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[S11:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT: %[[S12:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S11]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT: %[[S13:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<4x6xf32>) -> tensor<4x6xf32>
// CHECK-NEXT: %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT: %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S14]] : tensor<4x4xf32>) {
+// CHECK-NEXT: %[[S15:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S14]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[S16:.*]] = linalg.matmul ins(%[[S13]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S15]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[S17:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT: %[[S18:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S17]] : tensor<4x4xf32>) {
// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[IN]] : f32
// CHECK-NEXT: } -> tensor<4x4xf32>
-// CHECK-NEXT: %[[S16:.*]] = linalg.mul ins(%[[S15]], %[[S13]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT: %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK-NEXT: %[[S18:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG10]][%[[ARG7]], %[[S17]], %[[S18]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[S19:.*]] = linalg.mul ins(%[[S18]], %[[S16]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S17]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[S20:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT: %[[S21:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S9]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT: scf.yield %[[S10]] : tensor<2x12x12x2xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S8]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT: scf.yield %[[S9]] : tensor<2x12x12x2xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S7]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT: scf.yield %[[S8]] : tensor<2x12x12x2xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S7]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
index ec11a6ef8fbeee..1186bf8fe5aced 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -7,17 +7,19 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
// CHECK-LABEL: func.func @conv2d_4x4_3x3
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT: return %[[S7]] : tensor<2x4x4x2xf32>
// CHECK-NEXT: }
// -----
@@ -29,17 +31,19 @@ func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>
// CHECK-LABEL: func.func @conv2d_2x2_5x5
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT: return %[[S7]] : tensor<2x2x2x2xf32>
// CHECK-NEXT: }
// -----
@@ -51,17 +55,19 @@ func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>
// CHECK-LABEL: func.func @conv2d_1x4_1x3
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S0]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S2]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT: return %[[S7]] : tensor<2x1x4x2xf32>
// CHECK-NEXT: }
// -----
@@ -73,17 +79,19 @@ func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
// CHECK-LABEL: func.func @conv2d_4x1_3x1
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S0]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S2]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT: return %[[S7]] : tensor<2x4x1x2xf32>
// CHECK-NEXT: }
// -----
@@ -95,17 +103,19 @@ func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf3
// CHECK-LABEL: func.func @conv2d_aligned
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S2]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x8x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT: return %[[S7]] : tensor<2x8x8x2xf32>
// CHECK-NEXT: }
// -----
@@ -117,8 +127,8 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
// CHECK-LABEL: func.func @conv2d_unaligned
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
// CHECK-NEXT: ^bb0
@@ -127,16 +137,17 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %3 {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
// CHECK-NEXT: %[[PADDED_1:.*]] = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] {
// CHECK-NEXT: ^bb0
// CHECK-NEXT: tensor.yield %[[CST]] : f32
// CHECK-NEXT: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S7]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
// CHECK-NEXT: }
@@ -149,17 +160,19 @@ func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3
// CHECK-LABEL: func.func @conv2d_type_promotion
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-// CHECK-NEXT: return %[[S6]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT: return %[[S7]] : tensor<2x4x4x2xf32>
// CHECK-NEXT: }
// -----
>From 81540b1384d8fec5dd572ef489cb08318ff82555 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Thu, 12 Sep 2024 09:32:25 +0100
Subject: [PATCH 2/3] Clean up tests
Reduce diff noise in tests and add tensor.empty and linalg.fill checks in
decomposition tests.
---
.../transform-tile-and-winograd-rewrite.mlir | 190 +++++++++++-------
.../Linalg/winograd-conv2d-rewrite.mlir | 38 ++--
mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 130 ++++++------
3 files changed, 204 insertions(+), 154 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 21dcea968615f6..78d2e49cf6bb4b 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -36,6 +36,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func.func @conv2d
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
// CHECK: %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<6x4xf32>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK: %[[CST_3:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<{{.*}}> : tensor<3x6xf32>
+// CHECK: %[[CST_5:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
+// CHECK: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C5:.*]] = arith.constant 5 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
@@ -44,8 +51,12 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
-// CHECK: %[[S12:.*]] = linalg.matmul
-// CHECK: %[[S15:.*]] = linalg.matmul
+// CHECK: %[[S10:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK: %[[S11:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK: %[[S12:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S11]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK: %[[S13:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK: %[[S14:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[S15:.*]] = linalg.matmul ins(%[[S12]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S9]]
@@ -56,20 +67,24 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_8]])
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
// CHECK: %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK: %[[S16:.*]] = linalg.matmul
-// CHECK: %[[S19:.*]] = linalg.matmul
-// CHECK: %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK: scf.yield %[[INSERTED_SLICE_10]]
+// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK: %[[S14:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK: %[[S15:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[S16:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_8]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S15]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[S17:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK: %[[S18:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S17]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[S19:.*]] = linalg.matmul ins(%[[S16]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S18]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_9]]
// CHECK: scf.yield %[[S13]]
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S9]]
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
// CHECK: %[[S6:.*]] = linalg.batch_matmul
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
@@ -78,20 +93,24 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_8]])
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
// CHECK: %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S18:.*]] = linalg.matmul
-// CHECK: %[[S21:.*]] = linalg.matmul
-// CHECK: %[[S22:.*]] = tensor.empty()
+// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S16:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK: %[[S17:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S16]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK: %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_8]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK: %[[S19:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK: %[[S20:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[S21:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[S22:.*]] = tensor.empty() : tensor<4x4xf32>
// CHECK: %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S22]] : tensor<4x4xf32>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<4x4xf32>
// CHECK: %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK: %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
-// CHECK: scf.yield %[[INSERTED_SLICE_10]]
+// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_9]]
// CHECK: scf.yield %[[S15]]
// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
@@ -148,72 +167,91 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func.func @conv2d_unaligned
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
// CHECK: %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<6x4xf32>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK: %[[CST_3:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[CST_4:.*]] = arith.constant dense<{{.*}}> : tensor<3x6xf32>
+// CHECK: %[[CST_5:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C5:.*]] = arith.constant 5 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[S0:.*]] = tensor.empty()
-// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
-// CHECK: %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
-// CHECK: %[[S13:.*]] = linalg.matmul
-// CHECK: %[[S16:.*]] = linalg.matmul
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK: %[[S11:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK: %[[S12:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S11]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK: %[[S13:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S12]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK: %[[S14:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK: %[[S15:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[S16:.*]] = linalg.matmul ins(%[[S13]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S15]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
-// CHECK: scf.yield %[[S10]] : tensor<6x6x5x2xf32>
+// CHECK: scf.yield %[[S9]] : tensor<6x6x5x2xf32>
// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
-// CHECK: %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK: %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S11]], %[[S12]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S13:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK: %[[S14:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK: %[[S17:.*]] = linalg.matmul
-// CHECK: %[[S20:.*]] = linalg.matmul
-// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S20]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK: %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK: %[[S15:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK: %[[S16:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S15]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[S17:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_11]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S16]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[S18:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK: %[[S19:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S18]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[S20:.*]] = linalg.matmul ins(%[[S17]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S19]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S20]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
-// CHECK: scf.yield %[[S14]] : tensor<6x6x1x1x2x5xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
-// CHECK: scf.yield %[[S10]]
+// CHECK: scf.yield %[[S9]]
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
// CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
-// CHECK: %[[S7:.*]] = linalg.batch_matmul
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
+// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK: %[[S8:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK: %[[S9:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S8]])
-// CHECK: %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK: %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S11]], %[[S12]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK: %[[S13:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK: %[[S16:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S19:.*]] = linalg.matmul
-// CHECK: %[[S22:.*]] = linalg.matmul
+// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK: %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S17:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK: %[[S18:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK: %[[S19:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_11]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S18]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK: %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK: %[[S21:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[S22:.*]] = linalg.matmul ins(%[[S19]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[S23:.*]] = tensor.empty() : tensor<4x4xf32>
// CHECK: %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S23]] : tensor<4x4xf32>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<4x4xf32>
// CHECK: %[[S25:.*]] = linalg.mul ins(%[[S24]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S25]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S25]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE_12]]
-// CHECK: scf.yield %[[S16]] : tensor<2x4x4x2xf32>
-// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK: %[[S15:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, %[[S14]], %[[S15]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: scf.yield %[[S15]] : tensor<2x4x4x2xf32>
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
-// CHECK: scf.yield %[[S10]]
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S9]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
+// CHECK: scf.yield %[[S9]]
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
// CHECK: return %[[EXTRACTED_SLICE]]
// -----
@@ -255,15 +293,21 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func.func @conv2d_mx1_rx1
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
// CHECK: %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C5:.*]] = arith.constant 5 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
-// CHECK: %[[S10:.*]] = linalg.matmul
+// CHECK: %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK: %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK: %[[S10:.*]] = linalg.matmul ins(%[[CST_2]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S7]]
@@ -271,18 +315,24 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
-// CHECK: %[[S10:.*]] = linalg.matmul
+// CHECK: %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK: %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK: %[[S10:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE]] : tensor<6x6xf32>, tensor<6x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S7]]
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK: %[[COLLAPSED_4:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
-// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK: %[[S5:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_3]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
-// CHECK: %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
-// CHECK: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK: %[[S11:.*]] = linalg.matmul
+// CHECK: %[[S9:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK: %[[S10:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S9]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK: %[[S11:.*]] = linalg.matmul ins(%[[CST_0]], %[[EXTRACTED_SLICE]] : tensor<4x6xf32>, tensor<6x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
// CHECK: %[[S12:.*]] = tensor.empty() : tensor<4x1xf32>
// CHECK: %[[S13:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S12]] : tensor<4x1xf32>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
@@ -291,5 +341,5 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S14:.*]] = linalg.mul ins(%[[S13]], %[[S11]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S12]] : tensor<4x1xf32>) -> tensor<4x1xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
-// CHECK: scf.yield %[[S8]]
-// CHECK: return %[[S7]]
+// CHECK: scf.yield %[[S8]]
+// CHECK: return %[[S7]]
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 2ffd9fd9c0db21..4369f5f1eab4ca 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -45,7 +45,7 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-DAG: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], %[[C0]], %[[C0]], %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32>
// CHECK-NEXT: %[[S9:.*]] = tensor.empty() : tensor<6x3xf32>
// CHECK-NEXT: %[[S10:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S9]] : tensor<6x3xf32>) -> tensor<6x3xf32>
@@ -56,7 +56,7 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[C0]], %[[C0]], %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x5x2xf32>
// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S8]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT: scf.yield %[[S7]] : tensor<6x6x5x2xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
@@ -64,12 +64,12 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-NEXT: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT: %[[S10:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK-NEXT: %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S11]], %[[S12]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32>
+// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S10]], %[[S11]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32>
// CHECK-NEXT: %[[S13:.*]] = tensor.empty() : tensor<6x6xf32>
// CHECK-NEXT: %[[S14:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK-NEXT: %[[S15:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
@@ -79,11 +79,11 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S18]] into %[[ARG10]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S10]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: scf.yield %[[S7]] : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
// CHECK-NEXT: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
@@ -95,10 +95,10 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
// CHECK-NEXT: tensor.yield %[[CST_6]] : f32
// CHECK-NEXT: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT: %[[S10:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT: %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6xf32>
// CHECK-NEXT: %[[S11:.*]] = tensor.empty() : tensor<4x6xf32>
// CHECK-NEXT: %[[S12:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S11]] : tensor<4x6xf32>) -> tensor<4x6xf32>
@@ -117,12 +117,12 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S10]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT: scf.yield %[[S9]] : tensor<2x12x12x2xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S9]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT: scf.yield %[[S8]] : tensor<2x12x12x2xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[S8]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT: scf.yield %[[S7]] : tensor<2x12x12x2xf32>
// CHECK-NEXT: }
-// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S7]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
index 1186bf8fe5aced..0040d81a2d24e7 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -7,19 +7,19 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
// CHECK-LABEL: func.func @conv2d_4x4_3x3
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-// CHECK-NEXT: return %[[S7]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT: return %[[S9]] : tensor<2x4x4x2xf32>
// CHECK-NEXT: }
// -----
@@ -31,19 +31,19 @@ func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>
// CHECK-LABEL: func.func @conv2d_2x2_5x5
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
-// CHECK-NEXT: return %[[S7]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT: return %[[S9]] : tensor<2x2x2x2xf32>
// CHECK-NEXT: }
// -----
@@ -56,18 +56,18 @@ func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>
// CHECK-LABEL: func.func @conv2d_1x4_1x3
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S0]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S2]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
-// CHECK-NEXT: return %[[S7]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S7]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT: return %[[S9]] : tensor<2x1x4x2xf32>
// CHECK-NEXT: }
// -----
@@ -80,18 +80,18 @@ func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
// CHECK-LABEL: func.func @conv2d_4x1_3x1
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S0]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S2]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-// CHECK-NEXT: return %[[S7]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S7]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT: return %[[S9]] : tensor<2x4x1x2xf32>
// CHECK-NEXT: }
// -----
@@ -104,18 +104,18 @@ func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf3
// CHECK-LABEL: func.func @conv2d_aligned
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S2]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x8x2xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
-// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT: return %[[S7]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK-NEXT: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT: return %[[S9]] : tensor<2x8x8x2xf32>
// CHECK-NEXT: }
// -----
@@ -127,8 +127,8 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
// CHECK-LABEL: func.func @conv2d_unaligned
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
// CHECK-NEXT: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
// CHECK-NEXT: ^bb0
@@ -137,7 +137,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
// CHECK-NEXT: %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
// CHECK-NEXT: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
>From a4afe577a93ba839e8c6be65765761bd68a4dbce Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at celest.fr>
Date: Fri, 13 Sep 2024 15:46:57 +0100
Subject: [PATCH 3/3] Fix use of hardcoded SSA value
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Max191 <44243577+Max191 at users.noreply.github.com>
---
.../Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 78d2e49cf6bb4b..c5760acf94a88a 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -56,7 +56,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[S12:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S11]] : tensor<6x3xf32>) -> tensor<6x3xf32>
// CHECK: %[[S13:.*]] = tensor.empty() : tensor<6x6xf32>
// CHECK: %[[S14:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK: %[[S15:.*]] = linalg.matmul ins(%[[S12]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%14 : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK: %[[S15:.*]] = linalg.matmul ins(%[[S12]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S9]]
More information about the Mlir-commits
mailing list