[Mlir-commits] [mlir] Add missing FillOp to winograd lowering (PR #108181)

Thomas Preud'homme llvmlistbot at llvm.org
Fri Sep 13 07:47:06 PDT 2024


https://github.com/RoboTux updated https://github.com/llvm/llvm-project/pull/108181

>From 7f44e5429b52736a25e4996290772a2cc5e2d2ce Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Tue, 10 Sep 2024 17:15:41 +0100
Subject: [PATCH 1/3] [Linalg] Add missing FillOp to winograd lowering

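Winograd lowering involves a number of matmul and batch_matmul operations
whose out (accumulator) operand is currently the result of a tensor.empty.
Since the contents of a tensor.empty result are unspecified, accumulating
into it is undefined behaviour. This commit adds the necessary linalg.fill
to zero-initialize those operands.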
---
 .../Linalg/Transforms/WinogradConv2D.cpp      |  57 ++++--
 .../transform-tile-and-winograd-rewrite.mlir  | 179 +++++++++---------
 .../Linalg/winograd-conv2d-rewrite.mlir       |  96 +++++-----
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 147 +++++++-------
 4 files changed, 266 insertions(+), 213 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index b65b18699a15aa..80edf4a32c6df8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -390,6 +390,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
     TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
     Value matmulRetValue = extractFilter;
+    Value zero = builder.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix G.
       auto it = GMatrices.find(key);
@@ -399,8 +401,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 
       retRows = GMatrix.rows;
       auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
       Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
       // Multiply G x g.
@@ -418,8 +423,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 
       auto matmulType =
           RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
       Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
       // Multiply u = (G x g) x GT.
@@ -523,6 +531,8 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
     int64_t retRows = 1;
     int64_t retCols = 1;
     Value matmulRetValue = extractInput;
+    Value zero = builder.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix BT.
       auto it = BTMatrices.find(key);
@@ -532,8 +542,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
 
       retRows = BTMatrix.rows;
       auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
       Value BT =
           create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
@@ -552,8 +565,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
 
       retCols = BMatrix.cols;
       auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
       Value B =
           create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
       // Multiply v = (BT x d) x B.
@@ -636,8 +652,13 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
       {inputShape[0] * inputShape[1],
        inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
       outputElementType);
-  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                outputElementType);
+  Value empty = rewriter
+                    .create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                             outputElementType)
+                    .getResult();
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(outputElementType));
+  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
   auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
       loc, matmulType, ValueRange({collapseInput, collapseFilter}),
@@ -725,6 +746,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
     int64_t leftScalarFactor = 1;
     int64_t rightScalarFactor = 1;
     Value matmulRetValue = extractValue;
+    Value zero = builder.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix AT.
       auto it = ATMatrices.find(key);
@@ -735,8 +758,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
       leftScalarFactor = ATMatrix.scalarFactor;
       retRows = ATMatrix.rows;
       auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
       Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
       // Multiply AT x m.
@@ -756,8 +782,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
       auto matmulType =
           RankedTensorType::get({retRows, AMatrix.cols}, elementType);
       retCols = AMatrix.cols;
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
+      auto empty =
+          builder
+              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+              .getResult();
+      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
 
       Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
       // Multiply y = (AT x m) x A.
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 6bb3fb1423edc6..21dcea968615f6 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -44,9 +44,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
-// CHECK:      %[[S11:.*]] = linalg.matmul
-// CHECK:      %[[S13:.*]] = linalg.matmul
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = linalg.matmul
+// CHECK:      %[[S15:.*]] = linalg.matmul
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
@@ -56,20 +56,20 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK:      %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_8]])
 // CHECK:        %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK:          %[[S15:.*]] = linalg.matmul
-// CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S17]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
+// CHECK:          %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK:          %[[S16:.*]] = linalg.matmul
+// CHECK:          %[[S19:.*]] = linalg.matmul
+// CHECK:          %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_10]]
 // CHECK:        scf.yield %[[S13]]
 // CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
 // CHECK:  %[[S6:.*]] = linalg.batch_matmul
 // CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
 // CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
@@ -78,20 +78,20 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK:      %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_8]])
 // CHECK:        %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[S19:.*]] = linalg.matmul
-// CHECK:          %[[S20:.*]] = tensor.empty()
-// CHECK:          %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK:          %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S18:.*]] = linalg.matmul
+// CHECK:          %[[S21:.*]] = linalg.matmul
+// CHECK:          %[[S22:.*]] = tensor.empty()
+// CHECK:          %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S22]] : tensor<4x4xf32>) {
 // CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK:            linalg.yield %[[IN]] : f32
 // CHECK:          } -> tensor<4x4xf32>
-// CHECK:          %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
-// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
+// CHECK:          %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_10]]
 // CHECK:        scf.yield %[[S15]]
 // CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
@@ -114,14 +114,15 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
   %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
   %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
   %4 = tensor.empty() : tensor<36x18x2xf32>
-  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
-  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%5 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
   %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
   ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
     tensor.yield %cst : f32
   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-  %extracted_slice = tensor.extract_slice %6[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %extracted_slice = tensor.extract_slice %7[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
   return %extracted_slice : tensor<2x9x9x2xf32>
 }
 
@@ -153,71 +154,72 @@ module attributes {transform.with_named_sequence} {
 // CHECK:  %[[C2:.*]] = arith.constant 2 : index
 // CHECK:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK:  %[[S0:.*]] = tensor.empty()
-// CHECK:  %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
-// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
-// CHECK:      %[[S11:.*]] = linalg.matmul
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK:    %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
 // CHECK:      %[[S13:.*]] = linalg.matmul
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:      %[[S16:.*]] = linalg.matmul
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
-// CHECK:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK:    scf.yield %[[S10]] : tensor<6x6x5x2xf32>
 // CHECK:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
 // CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
-// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
-// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK:        %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
-// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK:          %[[S15:.*]] = linalg.matmul
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK:    %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:      %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S11]], %[[S12]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S13:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK:        %[[S14:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
 // CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S17]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S20:.*]] = linalg.matmul
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S20]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
 // CHECK:          scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
-// CHECK:        scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:        scf.yield %[[S14]] : tensor<6x6x1x1x2x5xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
-// CHECK:    scf.yield %[[S9]]
+// CHECK:    scf.yield %[[S10]]
 // CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
 // CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
-// CHECK:  %[[S6:.*]] = linalg.batch_matmul
-// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
+// CHECK:  %[[S7:.*]] = linalg.batch_matmul
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
 // CHECK:  %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK:  %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
-// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK:        %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
-// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          %[[S17:.*]] = linalg.matmul
+// CHECK:  %[[S8:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK:  %[[S9:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S8]])
+// CHECK:    %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:      %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S11]], %[[S12]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[S13:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK:        %[[S16:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
 // CHECK:          %[[S19:.*]] = linalg.matmul
-// CHECK:          %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK:          %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK:          %[[S22:.*]] = linalg.matmul
+// CHECK:          %[[S23:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK:          %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S23]] : tensor<4x4xf32>) {
 // CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK:            linalg.yield %[[IN]] : f32
 // CHECK:          } -> tensor<4x4xf32>
-// CHECK:          %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S22]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          %[[S25:.*]] = linalg.mul ins(%[[S24]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S25]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK:          scf.yield %[[INSERTED_SLICE_12]]
-// CHECK:        scf.yield %[[S15]] : tensor<2x4x4x2xf32>
-// CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:        scf.yield %[[S16]] : tensor<2x4x4x2xf32>
+// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:      %[[S15:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, %[[S14]], %[[S15]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
-// CHECK:    scf.yield %[[S9]]
-// CHECK:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
+// CHECK:    scf.yield %[[S10]]
+// CHECK:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S9]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
 // CHECK:  return %[[EXTRACTED_SLICE]]
 
 // -----
 
 func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<6x1x5x2xf32>
   %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
   %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
@@ -225,10 +227,11 @@ func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
   %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
   %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
   %4 = tensor.empty() : tensor<6x2x2xf32>
-  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-  return %6 : tensor<2x4x1x2xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+  return %7 : tensor<2x4x1x2xf32>
 }
 
 module attributes {transform.with_named_sequence} {
@@ -260,33 +263,33 @@ module attributes {transform.with_named_sequence} {
 // CHECK:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
 // CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
-// CHECK:       %[[S9:.*]] = linalg.matmul
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S10:.*]] = linalg.matmul
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
 // CHECK:     scf.yield %[[S7]]
 // CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
 // CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
 // CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
-// CHECK:       %[[S9:.*]] = linalg.matmul
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:       %[[S10:.*]] = linalg.matmul
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
 // CHECK:     scf.yield %[[S7]]
 // CHECK:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK:   %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
-// CHECK:   %[[S5:.*]] = linalg.batch_matmul
-// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
-// CHECK:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
-// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:   %[[COLLAPSED_4:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:   %[[S6:.*]] = linalg.batch_matmul
+// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
+// CHECK:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:       %[[S9:.*]] = linalg.matmul
-// CHECK:       %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
-// CHECK:       %[[S11:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S10]] : tensor<4x1xf32>) {
+// CHECK:       %[[S11:.*]] = linalg.matmul
+// CHECK:       %[[S12:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK:       %[[S13:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S12]] : tensor<4x1xf32>) {
 // CHECK:       ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK:         linalg.yield %[[IN]] : f32
 // CHECK:       } -> tensor<4x1xf32>
-// CHECK:       %[[S12:.*]] = linalg.mul ins(%[[S11]], %[[S9]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S14:.*]] = linalg.mul ins(%[[S13]], %[[S11]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S12]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
-// CHECK:     scf.yield %[[S7]]
-// CHECK:   return %[[S6]]
+// CHECK:     scf.yield %[[S8]]
+// CHECK:   return %[[S7]]
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 095a6636b68dc6..2ffd9fd9c0db21 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -13,14 +13,15 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
   %collapsed = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
   %collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
   %6 = tensor.empty() : tensor<36x18x2xf32>
-  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
-  %expanded = tensor.expand_shape %7 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+  %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+  %8 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%7 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+  %expanded = tensor.expand_shape %8 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
   %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
   ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
     tensor.yield %cst : f32
   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-  %extracted_slice = tensor.extract_slice %8[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+  %9 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %extracted_slice = tensor.extract_slice %9[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
   return %extracted_slice : tensor<2x9x9x2xf32>
 }
 
@@ -44,16 +45,18 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-DAG:   %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
 // CHECK-NEXT:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
 // CHECK-NEXT:       %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], %[[C0]], %[[C0]], %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32>
-// CHECK-NEXT:       %[[S8:.*]] = tensor.empty() : tensor<6x3xf32>
-// CHECK-NEXT:       %[[S9:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S8]] : tensor<6x3xf32>) -> tensor<6x3xf32>
-// CHECK-NEXT:       %[[S10:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:       %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[ARG6]][%[[C0]], %[[C0]], %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x5x2xf32>
+// CHECK-NEXT:       %[[S9:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:       %[[S10:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S9]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:       %[[S11:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:       %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:       %[[S13:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:       %[[S14:.*]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[C0]], %[[C0]], %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x5x2xf32>
 // CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
 // CHECK-NEXT:     }
-// CHECK-NEXT:     scf.yield %[[S7]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:     scf.yield %[[S8]] : tensor<6x6x5x2xf32>
 // CHECK-NEXT:   }
 // CHECK-NEXT:   %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
 // CHECK-NEXT:   ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
@@ -61,60 +64,65 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
 // CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT:       %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT:         %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT:           %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK-NEXT:           %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S10]], %[[S11]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32>
-// CHECK-NEXT:           %[[S12:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:           %[[S13:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:           %[[S14:.*]] = tensor.empty() : tensor<6x6xf32>
-// CHECK-NEXT:           %[[S15:.*]] = linalg.matmul ins(%[[S13]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG10]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:       %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:         %[[S10:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:           %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT:           %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S11]], %[[S12]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S13:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S14:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S15:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S16:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S17:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S16]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S18:.*]] = linalg.matmul ins(%[[S15]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S18]] into %[[ARG10]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:         }
-// CHECK-NEXT:         scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:         scf.yield %[[S10]] : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:       }
-// CHECK-NEXT:       scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:     }
-// CHECK-NEXT:     scf.yield %[[S7]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:     scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:   }
 // CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
 // CHECK-NEXT:   %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
 // CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_7]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_7]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
 // CHECK-NEXT:   %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
 // CHECK-NEXT:   ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
 // CHECK-NEXT:     tensor.yield %[[CST_6]] : f32
 // CHECK-NEXT:   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT:       %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT:         %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:       %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:         %[[S10:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
 // CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6xf32>
-// CHECK-NEXT:           %[[S10:.*]] = tensor.empty() : tensor<4x6xf32>
-// CHECK-NEXT:           %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32>
-// CHECK-NEXT:           %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:           %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S11:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:           %[[S12:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S11]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:           %[[S13:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK-NEXT:           %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:           %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S14]] : tensor<4x4xf32>) {
+// CHECK-NEXT:           %[[S15:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S14]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S16:.*]] = linalg.matmul ins(%[[S13]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S15]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S17:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S18:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S17]] : tensor<4x4xf32>) {
 // CHECK-NEXT:           ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK-NEXT:             linalg.yield %[[IN]] : f32
 // CHECK-NEXT:           } -> tensor<4x4xf32>
-// CHECK-NEXT:           %[[S16:.*]] = linalg.mul ins(%[[S15]], %[[S13]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT:           %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK-NEXT:           %[[S18:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG10]][%[[ARG7]], %[[S17]], %[[S18]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:           %[[S19:.*]] = linalg.mul ins(%[[S18]], %[[S16]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S17]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S20:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT:           %[[S21:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
 // CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
 // CHECK-NEXT:         }
-// CHECK-NEXT:         scf.yield %[[S9]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:         scf.yield %[[S10]] : tensor<2x12x12x2xf32>
 // CHECK-NEXT:       }
-// CHECK-NEXT:       scf.yield %[[S8]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:       scf.yield %[[S9]] : tensor<2x12x12x2xf32>
 // CHECK-NEXT:     }
-// CHECK-NEXT:     scf.yield %[[S7]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:     scf.yield %[[S8]] : tensor<2x12x12x2xf32>
 // CHECK-NEXT:   }
-// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S7]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
 // CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
 // CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
index ec11a6ef8fbeee..1186bf8fe5aced 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -7,17 +7,19 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
 
 // CHECK-LABEL: func.func @conv2d_4x4_3x3
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-// CHECK-NEXT:  return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:  %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  return %[[S7]] : tensor<2x4x4x2xf32>
 // CHECK-NEXT: }
 
 // -----
@@ -29,17 +31,19 @@ func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>
 
 // CHECK-LABEL: func.func @conv2d_2x2_5x5
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:   %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   return %[[S7]] : tensor<2x2x2x2xf32>
 // CHECK-NEXT: }
 
 // -----
@@ -51,17 +55,19 @@ func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>
 
 // CHECK-LABEL: func.func @conv2d_1x4_1x3
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:   %[[S0:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S0]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S2]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   return %[[S7]] : tensor<2x1x4x2xf32>
 // CHECK-NEXT: }
 
 // -----
@@ -73,17 +79,19 @@ func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
 
 // CHECK-LABEL: func.func @conv2d_4x1_3x1
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-// CHECK-NEXT:   return %[[S8]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S0]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S2]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   return %[[S7]] : tensor<2x4x1x2xf32>
 // CHECK-NEXT: }
 
 // -----
@@ -95,17 +103,19 @@ func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf3
 
 // CHECK-LABEL: func.func @conv2d_aligned
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
-// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:  %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S2]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  return %[[S7]] : tensor<2x8x8x2xf32>
 // CHECK-NEXT: }
 
 // -----
@@ -117,8 +127,8 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
 
 // CHECK-LABEL: func.func @conv2d_unaligned
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:  %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
 // CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
 // CHECK-NEXT:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
 // CHECK-NEXT:  ^bb0
@@ -127,16 +137,17 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
 // CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %3 {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
 // CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT:  %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
 // CHECK-NEXT:  %[[PADDED_1:.*]] = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] {
 // CHECK-NEXT:  ^bb0
 // CHECK-NEXT:    tensor.yield %[[CST]] : f32
 // CHECK-NEXT:  } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S7]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
 // CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
 // CHECK-NEXT: }
 
@@ -149,17 +160,19 @@ func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3
 
 // CHECK-LABEL: func.func @conv2d_type_promotion
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK:        %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:   %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
 // CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
 // CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
 // CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
 // CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
 // CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
 // CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-// CHECK-NEXT:   return %[[S6]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:   return %[[S7]] : tensor<2x4x4x2xf32>
 // CHECK-NEXT: }
 
 // -----

>From 81540b1384d8fec5dd572ef489cb08318ff82555 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Thu, 12 Sep 2024 09:32:25 +0100
Subject: [PATCH 2/3] Clean up tests

Reduce diff noise in tests and add tensor.empty and linalg.fill checks in
the decomposition tests.
---
 .../transform-tile-and-winograd-rewrite.mlir  | 190 +++++++++++-------
 .../Linalg/winograd-conv2d-rewrite.mlir       |  38 ++--
 mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 130 ++++++------
 3 files changed, 204 insertions(+), 154 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 21dcea968615f6..78d2e49cf6bb4b 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -36,6 +36,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: func.func @conv2d
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
 // CHECK:  %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK:  %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<6x4xf32>
+// CHECK:  %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK:  %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK:  %[[CST_3:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK:  %[[CST_4:.*]] = arith.constant dense<{{.*}}> : tensor<3x6xf32>
+// CHECK:  %[[CST_5:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
+// CHECK:  %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK:  %[[C5:.*]] = arith.constant 5 : index
 // CHECK:  %[[C2:.*]] = arith.constant 2 : index
@@ -44,8 +51,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = linalg.matmul
-// CHECK:      %[[S15:.*]] = linalg.matmul
+// CHECK:      %[[S10:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK:      %[[S11:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK:      %[[S12:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S11]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK:      %[[S13:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK:      %[[S14:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:      %[[S15:.*]] = linalg.matmul ins(%[[S12]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
 // CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
 // CHECK:    scf.yield %[[S9]]
@@ -56,20 +67,24 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_8]])
+// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
 // CHECK:        %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK:          %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK:          %[[S16:.*]] = linalg.matmul
-// CHECK:          %[[S19:.*]] = linalg.matmul
-// CHECK:          %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          scf.yield %[[INSERTED_SLICE_10]]
+// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK:          %[[S14:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK:          %[[S15:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:          %[[S16:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_8]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S15]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:          %[[S17:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK:          %[[S18:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S17]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:          %[[S19:.*]] = linalg.matmul ins(%[[S16]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S18]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
 // CHECK:        scf.yield %[[S13]]
 // CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
 // CHECK:  %[[S6:.*]] = linalg.batch_matmul
 // CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
 // CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
@@ -78,20 +93,24 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_8]])
+// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
 // CHECK:        %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK:          %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          %[[S18:.*]] = linalg.matmul
-// CHECK:          %[[S21:.*]] = linalg.matmul
-// CHECK:          %[[S22:.*]] = tensor.empty()
+// CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S16:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK:          %[[S17:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S16]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK:          %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_8]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK:          %[[S19:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK:          %[[S20:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[S21:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[S22:.*]] = tensor.empty() : tensor<4x4xf32>
 // CHECK:          %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S22]] : tensor<4x4xf32>) {
 // CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK:            linalg.yield %[[IN]] : f32
 // CHECK:          } -> tensor<4x4xf32>
 // CHECK:          %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
-// CHECK:          scf.yield %[[INSERTED_SLICE_10]]
+// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          scf.yield %[[INSERTED_SLICE_9]]
 // CHECK:        scf.yield %[[S15]]
 // CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
@@ -148,72 +167,91 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: func.func @conv2d_unaligned
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
 // CHECK:  %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK:  %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<6x4xf32>
+// CHECK:  %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK:  %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK:  %[[CST_3:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
 // CHECK:  %[[C3:.*]] = arith.constant 3 : index
+// CHECK:  %[[CST_4:.*]] = arith.constant dense<{{.*}}> : tensor<3x6xf32>
+// CHECK:  %[[CST_5:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
 // CHECK:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK:  %[[C5:.*]] = arith.constant 5 : index
 // CHECK:  %[[C2:.*]] = arith.constant 2 : index
 // CHECK:  %[[C0:.*]] = arith.constant 0 : index
+// CHECK:  %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:  %[[S0:.*]] = tensor.empty()
-// CHECK:  %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
-// CHECK:    %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
-// CHECK:      %[[S13:.*]] = linalg.matmul
-// CHECK:      %[[S16:.*]] = linalg.matmul
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK:      %[[S11:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK:      %[[S12:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S11]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK:      %[[S13:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S12]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK:      %[[S14:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK:      %[[S15:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:      %[[S16:.*]] = linalg.matmul ins(%[[S13]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S15]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
-// CHECK:    scf.yield %[[S10]] : tensor<6x6x5x2xf32>
+// CHECK:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
 // CHECK:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
 // CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
-// CHECK:    %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:      %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S11]], %[[S12]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S13:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK:        %[[S14:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
-// CHECK:          %[[S17:.*]] = linalg.matmul
-// CHECK:          %[[S20:.*]] = linalg.matmul
-// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S20]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK:        %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK:          %[[S15:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK:          %[[S16:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S15]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:          %[[S17:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_11]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S16]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:          %[[S18:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK:          %[[S19:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S18]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:          %[[S20:.*]] = linalg.matmul ins(%[[S17]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S19]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S20]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
 // CHECK:          scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
-// CHECK:        scf.yield %[[S14]] : tensor<6x6x1x1x2x5xf32>
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:        scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
-// CHECK:    scf.yield %[[S10]]
+// CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
 // CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
-// CHECK:  %[[S7:.*]] = linalg.batch_matmul
-// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
+// CHECK:  %[[S6:.*]] = linalg.batch_matmul
+// CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
 // CHECK:  %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK:  %[[S8:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK:  %[[S9:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S8]])
-// CHECK:    %[[S10:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
-// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
-// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:      %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S11]], %[[S12]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
-// CHECK:      %[[S13:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_10]])
-// CHECK:        %[[S16:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
-// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:          %[[S19:.*]] = linalg.matmul
-// CHECK:          %[[S22:.*]] = linalg.matmul
+// CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK:  %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
+// CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK:        %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S17:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK:          %[[S18:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK:          %[[S19:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_11]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S18]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK:          %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK:          %[[S21:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[S22:.*]] = linalg.matmul ins(%[[S19]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK:          %[[S23:.*]] = tensor.empty() : tensor<4x4xf32>
 // CHECK:          %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S23]] : tensor<4x4xf32>) {
 // CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK:            linalg.yield %[[IN]] : f32
 // CHECK:          } -> tensor<4x4xf32>
 // CHECK:          %[[S25:.*]] = linalg.mul ins(%[[S24]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S25]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S25]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK:          scf.yield %[[INSERTED_SLICE_12]]
-// CHECK:        scf.yield %[[S16]] : tensor<2x4x4x2xf32>
-// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:      %[[S15:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, %[[S14]], %[[S15]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:        scf.yield %[[S15]] : tensor<2x4x4x2xf32>
+// CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK:      %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
-// CHECK:    scf.yield %[[S10]]
-// CHECK:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S9]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
+// CHECK:    scf.yield %[[S9]]
+// CHECK:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
 // CHECK:  return %[[EXTRACTED_SLICE]]
 
 // -----
@@ -255,15 +293,21 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: func.func @conv2d_mx1_rx1
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
 // CHECK:   %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK:  %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK:  %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK:  %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
 // CHECK:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK:   %[[C5:.*]] = arith.constant 5 : index
 // CHECK:   %[[C2:.*]] = arith.constant 2 : index
 // CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[CST_3:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
 // CHECK:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
 // CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
-// CHECK:       %[[S10:.*]] = linalg.matmul
+// CHECK:       %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK:       %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[S10:.*]] = linalg.matmul ins(%[[CST_2]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
 // CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
 // CHECK:     scf.yield %[[S7]]
@@ -271,18 +315,24 @@ module attributes {transform.with_named_sequence} {
 // CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
 // CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
-// CHECK:       %[[S10:.*]] = linalg.matmul
+// CHECK:       %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK:       %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[S10:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE]] : tensor<6x6xf32>, tensor<6x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
 // CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
 // CHECK:     scf.yield %[[S7]]
 // CHECK:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
-// CHECK:   %[[COLLAPSED_4:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
-// CHECK:   %[[S6:.*]] = linalg.batch_matmul
+// CHECK:   %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:   %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK:   %[[S5:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_3]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
 // CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
-// CHECK:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
-// CHECK:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
-// CHECK:       %[[S11:.*]] = linalg.matmul
+// CHECK:       %[[S9:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK:       %[[S10:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S9]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK:       %[[S11:.*]] = linalg.matmul ins(%[[CST_0]], %[[EXTRACTED_SLICE]] : tensor<4x6xf32>, tensor<6x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
 // CHECK:       %[[S12:.*]] = tensor.empty() : tensor<4x1xf32>
 // CHECK:       %[[S13:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S12]] : tensor<4x1xf32>) {
 // CHECK:       ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
@@ -291,5 +341,5 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       %[[S14:.*]] = linalg.mul ins(%[[S13]], %[[S11]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S12]] : tensor<4x1xf32>) -> tensor<4x1xf32>
 // CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
-// CHECK:     scf.yield %[[S8]]
-// CHECK:   return %[[S7]]
+// CHECK:     scf.yield %[[S7]]
+// CHECK:   return %[[S6]]
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 2ffd9fd9c0db21..4369f5f1eab4ca 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -45,7 +45,7 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-DAG:   %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
 // CHECK-NEXT:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK-NEXT:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) {
 // CHECK-NEXT:       %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], %[[C0]], %[[C0]], %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32>
 // CHECK-NEXT:       %[[S9:.*]] = tensor.empty() : tensor<6x3xf32>
 // CHECK-NEXT:       %[[S10:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S9]] : tensor<6x3xf32>) -> tensor<6x3xf32>
@@ -56,7 +56,7 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[C0]], %[[C0]], %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x5x2xf32>
 // CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
 // CHECK-NEXT:     }
-// CHECK-NEXT:     scf.yield %[[S8]] : tensor<6x6x5x2xf32>
+// CHECK-NEXT:     scf.yield %[[S7]] : tensor<6x6x5x2xf32>
 // CHECK-NEXT:   }
 // CHECK-NEXT:   %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
 // CHECK-NEXT:   ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
@@ -64,12 +64,12 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
 // CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT:       %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT:         %[[S10:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK-NEXT:           %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK-NEXT:           %[[S12:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S11]], %[[S12]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32>
+// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:       %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:         %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK-NEXT:           %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT:           %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S10]], %[[S11]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32>
 // CHECK-NEXT:           %[[S13:.*]] = tensor.empty() : tensor<6x6xf32>
 // CHECK-NEXT:           %[[S14:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
 // CHECK-NEXT:           %[[S15:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
@@ -79,11 +79,11 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S18]] into %[[ARG10]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:         }
-// CHECK-NEXT:         scf.yield %[[S10]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:         scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:       }
-// CHECK-NEXT:       scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:       scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:     }
-// CHECK-NEXT:     scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT:     scf.yield %[[S7]] : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:   }
 // CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
 // CHECK-NEXT:   %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
@@ -95,10 +95,10 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:   ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index):
 // CHECK-NEXT:     tensor.yield %[[CST_6]] : f32
 // CHECK-NEXT:   } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT:       %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK-NEXT:         %[[S10:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:       %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:         %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
 // CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6xf32>
 // CHECK-NEXT:           %[[S11:.*]] = tensor.empty() : tensor<4x6xf32>
 // CHECK-NEXT:           %[[S12:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S11]] : tensor<4x6xf32>) -> tensor<4x6xf32>
@@ -117,12 +117,12 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
 // CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
 // CHECK-NEXT:         }
-// CHECK-NEXT:         scf.yield %[[S10]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:         scf.yield %[[S9]] : tensor<2x12x12x2xf32>
 // CHECK-NEXT:       }
-// CHECK-NEXT:       scf.yield %[[S9]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:       scf.yield %[[S8]] : tensor<2x12x12x2xf32>
 // CHECK-NEXT:     }
-// CHECK-NEXT:     scf.yield %[[S8]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:     scf.yield %[[S7]] : tensor<2x12x12x2xf32>
 // CHECK-NEXT:   }
-// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S7]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
 // CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
 // CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
index 1186bf8fe5aced..0040d81a2d24e7 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -7,19 +7,19 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
 
 // CHECK-LABEL: func.func @conv2d_4x4_3x3
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT:  %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-// CHECK-NEXT:  return %[[S7]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:  %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:  %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  return %[[S9]] : tensor<2x4x4x2xf32>
 // CHECK-NEXT: }
 
 // -----
@@ -31,19 +31,19 @@ func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>
 
 // CHECK-LABEL: func.func @conv2d_2x2_5x5
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT:   %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
-// CHECK-NEXT:   return %[[S7]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S9:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT:   return %[[S9]] : tensor<2x2x2x2xf32>
 // CHECK-NEXT: }
 
 // -----
@@ -56,18 +56,18 @@ func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>
 // CHECK-LABEL: func.func @conv2d_1x4_1x3
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
 // CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT:   %[[S0:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S0]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S2]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
-// CHECK-NEXT:   return %[[S7]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S7]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT:   return %[[S9]] : tensor<2x1x4x2xf32>
 // CHECK-NEXT: }
 
 // -----
@@ -80,18 +80,18 @@ func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
 // CHECK-LABEL: func.func @conv2d_4x1_3x1
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
 // CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK-NEXT:   %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S0]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
-// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S2]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
-// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK-NEXT:   %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-// CHECK-NEXT:   return %[[S7]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S7]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK-NEXT:   %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT:   return %[[S9]] : tensor<2x4x1x2xf32>
 // CHECK-NEXT: }
 
 // -----
@@ -104,18 +104,18 @@ func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf3
 // CHECK-LABEL: func.func @conv2d_aligned
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
 // CHECK-NEXT:  %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT:  %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
-// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S2]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
-// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
-// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x8x2xf32>
-// CHECK-NEXT:  %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
-// CHECK-NEXT:  %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
-// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK-NEXT:  %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT:  return %[[S7]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S7]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S8]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK-NEXT:  %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  return %[[S9]] : tensor<2x8x8x2xf32>
 // CHECK-NEXT: }
 
 // -----
@@ -127,8 +127,8 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
 
 // CHECK-LABEL: func.func @conv2d_unaligned
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK-NEXT:  %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT:  %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
 // CHECK-NEXT:  %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
 // CHECK-NEXT:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
 // CHECK-NEXT:  ^bb0
@@ -137,7 +137,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
 // CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
 // CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %3 {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
 // CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
 // CHECK-NEXT:  %[[S5:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
 // CHECK-NEXT:  %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S5]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>

>From a4afe577a93ba839e8c6be65765761bd68a4dbce Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at celest.fr>
Date: Fri, 13 Sep 2024 15:46:57 +0100
Subject: [PATCH 3/3] Fix use of hardcoded SSA value
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Max191 <44243577+Max191 at users.noreply.github.com>
---
 .../Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir     | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 78d2e49cf6bb4b..c5760acf94a88a 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -56,7 +56,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      %[[S12:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S11]] : tensor<6x3xf32>) -> tensor<6x3xf32>
 // CHECK:      %[[S13:.*]] = tensor.empty() : tensor<6x6xf32>
 // CHECK:      %[[S14:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S13]] : tensor<6x6xf32>) -> tensor<6x6xf32>
-// CHECK:      %[[S15:.*]] = linalg.matmul ins(%[[S12]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%14 : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK:      %[[S15:.*]] = linalg.matmul ins(%[[S12]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32>
 // CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
 // CHECK:    scf.yield %[[S9]]



More information about the Mlir-commits mailing list