[Mlir-commits] [mlir] [mlir][linalg] Fix for bias handling for Winograd (PR #110331)

Dmitriy Smirnov llvmlistbot at llvm.org
Fri Oct 4 06:23:01 PDT 2024


https://github.com/d-smirnov updated https://github.com/llvm/llvm-project/pull/110331

From f480e97330ccb31856489dd67226ad3f1660001f Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Fri, 27 Sep 2024 19:30:51 +0100
Subject: [PATCH 1/3] [mlir][linalg] Fix for bias handling for Winograd

Patch adds handling of bias to the Winograd output transform op decomposition

Signed-off-by: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
---
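A rough sketch of what the decomposition now emits per output tile (shapes and
value names are illustrative, mirroring the updated CHECK lines below rather
than lifted verbatim from the transform): the current tile is read back from
the destination tensor and accumulated, instead of being overwritten:

  %init = tensor.extract_slice %out[%n, %h, %w, %f] [1, 4, 4, 1] [1, 1, 1, 1]
      : tensor<2x12x12x2xf32> to tensor<4x4xf32>
  %empty = tensor.empty() : tensor<4x4xf32>
  // Accumulate the computed tile into whatever the output already holds
  // (e.g. a pre-initialized bias).
  %sum = linalg.add ins(%tile, %init : tensor<4x4xf32>, tensor<4x4xf32>)
      outs(%empty : tensor<4x4xf32>) -> tensor<4x4xf32>
  %res = tensor.insert_slice %sum into %out[%n, %h, %w, %f] [1, 4, 4, 1] [1, 1, 1, 1]
      : tensor<4x4xf32> into tensor<2x12x12x2xf32>
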
 .../Linalg/Transforms/WinogradConv2D.cpp        | 17 ++++++++++++++++-
 .../transform-tile-and-winograd-rewrite.mlir    | 15 ++++++++++++---
 .../Dialect/Linalg/winograd-conv2d-rewrite.mlir |  5 ++++-
 3 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 80edf4a32c6df8..06f0aebbd2d559 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -837,9 +837,24 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
     Value widthOffset =
         builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
 
+    Value outInitVal =
+        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
+                            widthOffset, retRows, retCols,
+                            /*loopNorFIdx=*/0,
+                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
+                            /*widthIdx=*/2);
+    Value outVal =
+        builder
+            .create<linalg::AddOp>(
+                loc, outInitVal.getType(), ValueRange{matmulRetValue, outInitVal},
+                ValueRange{builder.create<tensor::EmptyOp>(
+                    loc, llvm::cast<ShapedType>(outInitVal.getType()).getShape(),
+                    elementType)})
+            .getResult(0);
+
     // Insert (H, W) to (N, H, W, F).
     Value combinedVal =
-        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
+        insert2DDataTo4D(builder, loc, outVal, args[0], NIter, FIter,
                          heightOffset, widthOffset, retRows, retCols,
                          /*loopNorFIdx=*/0,
                          /*loopCorFIdx=*/3, /*heightIdx=*/1,
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index c5760acf94a88a..01c0d0a826c999 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -109,7 +109,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK:            linalg.yield %[[IN]] : f32
 // CHECK:          } -> tensor<4x4xf32>
 // CHECK:          %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          %[[S25:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          %[[S26:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK:          %[[S27:.*]] = linalg.add ins(%[[S24]], %[[S25]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S26]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S27]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK:          scf.yield %[[INSERTED_SLICE_9]]
 // CHECK:        scf.yield %[[S15]]
 // CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
@@ -243,7 +246,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK:            linalg.yield %[[IN]] : f32
 // CHECK:          } -> tensor<4x4xf32>
 // CHECK:          %[[S25:.*]] = linalg.mul ins(%[[S24]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S25]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          %[[S26:.*]] = tensor.extract_slice %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          %[[S27:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK:          %[[S28:.*]] = linalg.add ins(%[[S25]], %[[S26]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S27]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S28]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK:          scf.yield %[[INSERTED_SLICE_12]]
 // CHECK:        scf.yield %[[S15]] : tensor<2x4x4x2xf32>
 // CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
@@ -339,7 +345,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK:         linalg.yield %[[IN]] : f32
 // CHECK:       } -> tensor<4x1xf32>
 // CHECK:       %[[S14:.*]] = linalg.mul ins(%[[S13]], %[[S11]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S12]] : tensor<4x1xf32>) -> tensor<4x1xf32>
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S15:.*]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S16:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK:       %[[S17:.*]] = linalg.add ins(%[[S14]], %[[S15]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S16]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S17]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
 // CHECK:     scf.yield %[[S7]]
 // CHECK:   return %[[S6]]
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 4369f5f1eab4ca..b24a93bc6c27ee 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -114,7 +114,10 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:           %[[S19:.*]] = linalg.mul ins(%[[S18]], %[[S16]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S17]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK-NEXT:           %[[S20:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK-NEXT:           %[[S21:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:           %[[S22:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<4x4xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S24:.*]] = linalg.add ins(%[[S19]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
 // CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
 // CHECK-NEXT:         }
 // CHECK-NEXT:         scf.yield %[[S9]] : tensor<2x12x12x2xf32>

From e081ff9508bf5dc1dc5b0de93c5aaba6b97d1eff Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Mon, 30 Sep 2024 10:53:58 +0100
Subject: [PATCH 2/3] [mlir][linalg] Adds DestinationStyleOpInterface

Tagged the linalg.winograd_output_transform op with DestinationStyleOpInterface
---
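With the op implementing DestinationStyleOpInterface, the tiling driver threads
the real output tensor through the generated loop nest as an iter_arg instead
of materializing a fresh tensor.empty, which is what the updated CHECK lines
below verify. A minimal self-contained sketch of that structure (the function
and the concrete shape are hypothetical):

  func.func @sketch(%output: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    // The destination is carried through iter_args, so every iteration
    // slices from (and re-inserts into) the tensor the previous one yielded.
    %res = scf.for %i = %c0 to %c2 step %c1
        iter_args(%dest = %output) -> (tensor<2x8x8x2xf32>) {
      %tile = tensor.extract_slice %dest[0, 0, 0, 0] [2, 4, 4, 2] [1, 1, 1, 1]
          : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
      // ... transform %tile here ...
      %upd = tensor.insert_slice %tile into %dest[0, 0, 0, 0] [2, 4, 4, 2] [1, 1, 1, 1]
          : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
      scf.yield %upd : tensor<2x8x8x2xf32>
    }
    return %res : tensor<2x8x8x2xf32>
  }
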
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  3 ++-
 .../transform-tile-and-winograd-rewrite.mlir  | 12 ++++-----
 .../Linalg/transform-tile-winograd.mlir       | 26 +++++++++----------
 3 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 5b6a90f806bedd..e42fd5d2ce13c1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -313,7 +313,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
 }
 
 def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
-    [AllElementTypesMatch<["value", "output"]>,
+    [AllElementTypesMatch<["value", "output"]>, DestinationStyleOpInterface,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
@@ -396,6 +396,7 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
     int64_t getOutputFDim() {
       return 3;
     }
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 01c0d0a826c999..13c0d49d9b7f50 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -85,15 +85,15 @@ module attributes {transform.with_named_sequence} {
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
 // CHECK:  %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[S7:.*]] = tensor.empty()
 // CHECK:  %[[S6:.*]] = linalg.batch_matmul
 // CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
-// CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]])
+// CHECK:  %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
 // CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
 // CHECK:        %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
 // CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
@@ -221,16 +221,16 @@ module attributes {transform.with_named_sequence} {
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
 // CHECK:  %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:  %[[S7:.*]] = tensor.empty()
 // CHECK:  %[[S6:.*]] = linalg.batch_matmul
 // CHECK:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
 // CHECK:  %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK:  %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK:  %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
+// CHECK:  %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[PADDED_8]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
 // CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
 // CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
 // CHECK:        %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
 // CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
index 21522a2083b463..9598c434aadb8f 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -279,14 +279,14 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C2_1:.*]] = arith.constant 2 : index
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C1_2:.*]] = arith.constant 1 : index
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG1]]) -> (tensor<2x8x8x2xf32>)
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]] iter_args(%[[ARG6:.*]] = %[[ARG5]]) -> (tensor<2x8x8x2xf32>)
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
 // CHECK:       %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
 // CHECK:       %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
 // CHECK:       %[[S5:.*]] = affine.apply #[[$MAP1]]()
 // CHECK:       %[[S6:.*]] = affine.apply #[[$MAP1]]()
-// CHECK:       %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32>
+// CHECK:       %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG6]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32>
 
 // -----
 
@@ -321,10 +321,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:    %[[C2_3:.*]] = arith.constant 2 : index
 // CHECK-DAG:    %[[C2_5:.*]] = arith.constant 2 : index
 // CHECK-DAG:    %[[C2_7:.*]] = arith.constant 2 : index
-// CHECK:    %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
-// CHECK:      %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
-// CHECK:        %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]]
-// CHECK:          %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]]
+// CHECK:    %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]] iter_args(%[[ARG9:.*]] = %[[ARG1]]) -> (tensor<3x8x8x5xf32>)
+// CHECK:      %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]] iter_args(%[[ARG10:.*]] = %[[ARG9]]) -> (tensor<3x8x8x5xf32>)
+// CHECK:        %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]] iter_args(%[[ARG11:.*]] = %[[ARG10]])
+// CHECK:          %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]] iter_args(%[[ARG12:.*]] = %[[ARG11]])
 // CHECK:            %[[C3_8:.*]] = arith.constant 3 : index
 // CHECK:            %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG6]])
 // CHECK:            %[[C5_9:.*]] = arith.constant 5 : index
@@ -334,7 +334,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK:            %[[S8:.*]] = affine.apply #[[$MAP2]](%[[ARG4]])
 // CHECK:            %[[S9:.*]] = affine.apply #[[$MAP3]]()
 // CHECK:            %[[S10:.*]] = affine.apply #[[$MAP3]]()
-// CHECK:            %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor<?x?x?x?xf32>
+// CHECK:            %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor<?x?x?x?xf32>
 
 // -----
 
@@ -367,14 +367,14 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C1_2:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C1_4:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C1_6:.*]] = arith.constant 1 : index
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]]
-// CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]]
-// CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]]
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[ARG1]]) -> (tensor<3x8x1x5xf32>)
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]] iter_args(%[[ARG10:.*]] = %[[ARG9]]) -> (tensor<3x8x1x5xf32>)
+// CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]] iter_args(%[[ARG11:.*]] = %[[ARG10]]) -> (tensor<3x8x1x5xf32>)
+// CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]] iter_args(%[[ARG12:.*]] = %[[ARG11]]) -> (tensor<3x8x1x5xf32>)
 // CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32>
 // CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
 // CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
 // CHECK:           %[[S7:.*]] = affine.apply #[[$MAP1]]()
 // CHECK:           %[[S8:.*]] = affine.apply #[[$MAP1]]()
-// CHECK:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
+// CHECK:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
 // CHECK:           %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>

From b153433de9bee0432f5a8851ff3907fed40e3c62 Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Fri, 4 Oct 2024 14:09:13 +0100
Subject: [PATCH 3/3] Addressed comments

---
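The main change here: the broadcast + linalg.mul + linalg.add chain is folded
into a single linalg.generic that multiplies by the scalar factor and
accumulates directly into the slice extracted from the destination. A sketch of
the fused op (affine maps and value names are illustrative, mirroring the CHECK
lines below):

  #scalar = affine_map<(d0, d1) -> ()>
  #id = affine_map<(d0, d1) -> (d0, d1)>
  // %factor : f32 scalar factor, %tile : transformed tile,
  // %init : slice extracted from the output (may carry the bias).
  %acc = linalg.generic
      {indexing_maps = [#scalar, #id, #id],
       iterator_types = ["parallel", "parallel"]}
      ins(%factor, %tile : f32, tensor<4x4xf32>)
      outs(%init : tensor<4x4xf32>) {
  ^bb0(%f: f32, %t: f32, %out: f32):
    %m = arith.mulf %f, %t : f32
    %a = arith.addf %m, %out : f32
    linalg.yield %a : f32
  } -> tensor<4x4xf32>
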
 .../Linalg/Transforms/WinogradConv2D.cpp      | 131 ++++++++----------
 .../transform-tile-and-winograd-rewrite.mlir  |  48 +++----
 .../Linalg/winograd-conv2d-rewrite.mlir       |  20 ++-
 3 files changed, 86 insertions(+), 113 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 06f0aebbd2d559..79f77822116fd7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -729,6 +729,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
 
   auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                        ValueRange args) -> scf::ValueVector {
+    auto context = builder.getContext();
     Value tileHIter = ivs[0];
     Value tileWIter = ivs[1];
     Value NIter = ivs[2];
@@ -740,29 +741,41 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                             FIter, 2, 3, /*loopNorFIdx=*/4,
                             /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);
 
-    TransformMapKeyTy key = {m, r};
-    int64_t retRows = 1;
-    int64_t retCols = 1;
-    int64_t leftScalarFactor = 1;
-    int64_t rightScalarFactor = 1;
+    const TransformMapKeyTy key = {m, r};
+    const TransformMatrix &AMatrix = AMatrices.at(key);
+    const TransformMatrix &ATMatrix = ATMatrices.at(key);
+    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
+                           (leftTransform ? ATMatrix.scalarFactor : 1);
+    int64_t retCols = rightTransform ? AMatrix.cols : 1;
+    int64_t retRows = leftTransform ? ATMatrix.rows : 1;
+
     Value matmulRetValue = extractValue;
     Value zero = builder.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
-    if (leftTransform) {
-      // Get constant transform matrix AT.
-      auto it = ATMatrices.find(key);
-      if (it == ATMatrices.end())
-        return {};
-      const TransformMatrix &ATMatrix = it->second;
 
-      leftScalarFactor = ATMatrix.scalarFactor;
-      retRows = ATMatrix.rows;
+    auto affineMap =
+        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+    Value heightOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
+    Value widthOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+
+    Value outInitVal =
+        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
+                            widthOffset, retRows, retCols,
+                            /*loopNorFIdx=*/0,
+                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
+                            /*widthIdx=*/2);
+    if (leftTransform) {
       auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
-      auto empty =
-          builder
-              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
-              .getResult();
-      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+      Value init = outInitVal;
+      if (rightTransform || scalarFactor != 1) {
+        auto empty = builder
+                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType)
+                         .getResult();
+        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+      }
 
       Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
       // Multiply AT x m.
@@ -772,21 +785,16 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
     }
 
     if (rightTransform) {
-      // Get constant transform matrix T.
-      auto it = AMatrices.find(key);
-      if (it == AMatrices.end())
-        return {};
-      const TransformMatrix &AMatrix = it->second;
-
-      rightScalarFactor = AMatrix.scalarFactor;
       auto matmulType =
           RankedTensorType::get({retRows, AMatrix.cols}, elementType);
-      retCols = AMatrix.cols;
-      auto empty =
-          builder
-              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
-              .getResult();
-      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+      Value init = outInitVal;
+      if (scalarFactor != 1) {
+        auto empty = builder
+                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType)
+                         .getResult();
+        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+      }
 
       Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
       // Multiply y = (AT x m) x A.
@@ -795,66 +803,39 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
       matmulRetValue = matmulOp.getResult(0);
     }
 
-    if (leftScalarFactor * rightScalarFactor != 1) {
-      // Multiply scalar factor.
-      Value scalarFactor = builder.create<arith::ConstantOp>(
-          loc,
-          FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+    if (scalarFactor != 1) {
+      // Multiply by scalar factor and add outInitVal.
+      Value scalarFactorValue = builder.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, scalarFactor));
       auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
-
       auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
       SmallVector<AffineMap> affineMaps = {
-          AffineMap::get(2, 0, init.getContext()), identityAffineMap};
-      auto broadcastedScalar =
+          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};
+
+      matmulRetValue =
           rewriter
               .create<linalg::GenericOp>(
-                  loc, matmulType, ValueRange{scalarFactor}, ValueRange{init},
-                  affineMaps,
+                  loc, matmulType,
+                  ValueRange{scalarFactorValue, matmulRetValue},
+                  ValueRange{outInitVal}, affineMaps,
                   llvm::ArrayRef<utils::IteratorType>{
                       utils::IteratorType::parallel,
                       utils::IteratorType::parallel},
                   [&](OpBuilder &nestedBuilder, Location nestedLoc,
                       ValueRange args) {
-                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+                    auto mulf = nestedBuilder.create<arith::MulFOp>(
+                        nestedLoc, args[0], args[1]);
+                    auto addf = nestedBuilder.create<arith::AddFOp>(
+                        nestedLoc, mulf.getResult(), args[2]);
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc,
+                                                          addf.getResult());
                   })
               .getResult(0);
-
-      matmulRetValue = builder
-                           .create<linalg::MulOp>(
-                               loc, matmulType,
-                               ValueRange{broadcastedScalar, matmulRetValue},
-                               ValueRange{init})
-                           .getResult(0);
     }
 
-    auto context = builder.getContext();
-    auto affineMap =
-        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-    Value heightOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
-    Value widthOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
-
-    Value outInitVal =
-        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
-                            widthOffset, retRows, retCols,
-                            /*loopNorFIdx=*/0,
-                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
-                            /*widthIdx=*/2);
-    Value outVal =
-        builder
-            .create<linalg::AddOp>(
-                loc, outInitVal.getType(), ValueRange{matmulRetValue, outInitVal},
-                ValueRange{builder.create<tensor::EmptyOp>(
-                    loc, llvm::cast<ShapedType>(outInitVal.getType()).getShape(),
-                    elementType)})
-            .getResult(0);
-
     // Insert (H, W) to (N, H, W, F).
     Value combinedVal =
-        insert2DDataTo4D(builder, loc, outVal, args[0], NIter, FIter,
+        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                          heightOffset, widthOffset, retRows, retCols,
                          /*loopNorFIdx=*/0,
                          /*loopCorFIdx=*/3, /*heightIdx=*/1,
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 13c0d49d9b7f50..776dc5b748c846 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -97,22 +97,20 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
 // CHECK:        %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
 // CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S25:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK:          %[[S16:.*]] = tensor.empty() : tensor<4x6xf32>
 // CHECK:          %[[S17:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S16]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK:          %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_8]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK:          %[[S19:.*]] = tensor.empty() : tensor<4x4xf32>
 // CHECK:          %[[S20:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK:          %[[S21:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[S22:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK:          %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S22]] : tensor<4x4xf32>) {
-// CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK:            linalg.yield %[[IN]] : f32
+// CHECK:          %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S21]] : f32, tensor<4x4xf32>) outs(%[[S25]] : tensor<4x4xf32>) {
+// CHECK:          ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:             %[[VAL_90:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
+// CHECK:             %[[VAL_91:.*]] = arith.addf %[[VAL_90]], %[[OUT]] : f32
+// CHECK:            linalg.yield %[[VAL_91]] : f32
 // CHECK:          } -> tensor<4x4xf32>
-// CHECK:          %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[S25:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
-// CHECK:          %[[S26:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK:          %[[S27:.*]] = linalg.add ins(%[[S24]], %[[S25]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S26]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S27]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S23]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK:          scf.yield %[[INSERTED_SLICE_9]]
 // CHECK:        scf.yield %[[S15]]
 // CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
@@ -234,22 +232,20 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
 // CHECK:        %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
 // CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:          %[[S26:.*]] = tensor.extract_slice %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK:          %[[S17:.*]] = tensor.empty() : tensor<4x6xf32>
 // CHECK:          %[[S18:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK:          %[[S19:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_11]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S18]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK:          %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
 // CHECK:          %[[S21:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK:          %[[S22:.*]] = linalg.matmul ins(%[[S19]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[S23:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK:          %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S23]] : tensor<4x4xf32>) {
-// CHECK:          ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK:            linalg.yield %[[IN]] : f32
+// CHECK:          %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S22]] : f32, tensor<4x4xf32>) outs(%[[S26]] : tensor<4x4xf32>) {
+// CHECK:          ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:             %[[VAL_104:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
+// CHECK:             %[[VAL_105:.*]] = arith.addf %[[VAL_104]], %[[OUT]] : f32
+// CHECK:            linalg.yield %[[VAL_105]] : f32
 // CHECK:          } -> tensor<4x4xf32>
-// CHECK:          %[[S25:.*]] = linalg.mul ins(%[[S24]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[S26:.*]] = tensor.extract_slice %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
-// CHECK:          %[[S27:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK:          %[[S28:.*]] = linalg.add ins(%[[S25]], %[[S26]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S27]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S28]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK:          %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S24]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK:          scf.yield %[[INSERTED_SLICE_12]]
 // CHECK:        scf.yield %[[S15]] : tensor<2x4x4x2xf32>
 // CHECK:      %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
@@ -336,19 +332,17 @@ module attributes {transform.with_named_sequence} {
 // CHECK:   %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
 // CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:       %[[S15:.*]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
 // CHECK:       %[[S9:.*]] = tensor.empty() : tensor<4x1xf32>
 // CHECK:       %[[S10:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S9]] : tensor<4x1xf32>) -> tensor<4x1xf32>
 // CHECK:       %[[S11:.*]] = linalg.matmul ins(%[[CST_0]], %[[EXTRACTED_SLICE]] : tensor<4x6xf32>, tensor<6x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
-// CHECK:       %[[S12:.*]] = tensor.empty() : tensor<4x1xf32>
-// CHECK:       %[[S13:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S12]] : tensor<4x1xf32>) {
-// CHECK:       ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK:         linalg.yield %[[IN]] : f32
+// CHECK:       %[[S13:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S11]] : f32, tensor<4x1xf32>) outs(%[[S15]] : tensor<4x1xf32>) {
+// CHECK:       ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:          %[[VAL_57:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
+// CHECK:          %[[VAL_58:.*]] = arith.addf %[[VAL_57]], %[[OUT]] : f32
+// CHECK:         linalg.yield %[[VAL_58]] : f32
 // CHECK:       } -> tensor<4x1xf32>
-// CHECK:       %[[S14:.*]] = linalg.mul ins(%[[S13]], %[[S11]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S12]] : tensor<4x1xf32>) -> tensor<4x1xf32>
-// CHECK:       %[[S15:.*]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
-// CHECK:       %[[S16:.*]] = tensor.empty() : tensor<4x1xf32>
-// CHECK:       %[[S17:.*]] = linalg.add ins(%[[S14]], %[[S15]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S16]] : tensor<4x1xf32>) -> tensor<4x1xf32>
-// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S17]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
 // CHECK:     scf.yield %[[S7]]
 // CHECK:   return %[[S6]]
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index b24a93bc6c27ee..16d06a74732729 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -100,24 +100,22 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:       %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) {
 // CHECK-NEXT:         %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) {
 // CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S20:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT:           %[[S21:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK-NEXT:           %[[S22:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<4x4xf32>
 // CHECK-NEXT:           %[[S11:.*]] = tensor.empty() : tensor<4x6xf32>
 // CHECK-NEXT:           %[[S12:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S11]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK-NEXT:           %[[S13:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK-NEXT:           %[[S14:.*]] = tensor.empty() : tensor<4x4xf32>
 // CHECK-NEXT:           %[[S15:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S14]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK-NEXT:           %[[S16:.*]] = linalg.matmul ins(%[[S13]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S15]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT:           %[[S17:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:           %[[S18:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S17]] : tensor<4x4xf32>) {
-// CHECK-NEXT:           ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT:             linalg.yield %[[IN]] : f32
+// CHECK-NEXT:           %[[S18:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S16]] : f32, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) {
+// CHECK-NEXT:           ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:             %[[VAL_98:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT:             %[[VAL_99:.*]] = arith.addf %[[VAL_98]], %[[OUT]] : f32
+// CHECK-NEXT:             linalg.yield %[[VAL_99]] : f32
 // CHECK-NEXT:           } -> tensor<4x4xf32>
-// CHECK-NEXT:           %[[S19:.*]] = linalg.mul ins(%[[S18]], %[[S16]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S17]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT:           %[[S20:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK-NEXT:           %[[S21:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK-NEXT:           %[[S22:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<4x4xf32>
-// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT:           %[[S24:.*]] = linalg.add ins(%[[S19]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S18]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
 // CHECK-NEXT:           scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
 // CHECK-NEXT:         }
 // CHECK-NEXT:         scf.yield %[[S9]] : tensor<2x12x12x2xf32>


