[Mlir-commits] [mlir] [MLIR][Linalg] Fixes for Winograd decomposition and for tiling (PR #123675)

Dmitriy Smirnov llvmlistbot at llvm.org
Wed Jan 22 13:52:55 PST 2025


https://github.com/d-smirnov updated https://github.com/llvm/llvm-project/pull/123675

>From c26282847e07b3732207853c40909c995a938d3a Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Fri, 10 Jan 2025 11:12:00 +0000
Subject: [PATCH 1/2] [MLIR][Linalg] Fixes for Winograd decomposition and for
 tiling

The PR addresses issues with filters 1 x r and r x 1 and with tiling

Signed-off-by: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |   6 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  36 +++---
 .../Linalg/Transforms/WinogradConv2D.cpp      |  34 +++---
 .../transform-tile-and-winograd-rewrite.mlir  | 113 +++++++++++++++++-
 .../Linalg/transform-tile-winograd.mlir       |  62 +++++-----
 5 files changed, 186 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index e42fd5d2ce13c1..f8df828f74851b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -155,7 +155,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
 }
 
 def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
-    [AllElementTypesMatch<["filter", "output"]>,
+    [AllElementTypesMatch<["filter", "output"]>, DestinationStyleOpInterface,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
@@ -220,12 +220,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
     int64_t getFilterCDim() {
       return 3;
     }
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
   }];
   let hasVerifier = 1;
 }
 
 def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
-    [AllElementTypesMatch<["input", "output"]>,
+    [AllElementTypesMatch<["input", "output"]>, DestinationStyleOpInterface,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
@@ -308,6 +309,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
     int64_t getOutputCDim() {
       return 5;
     }
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..0649fbc8e9549b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3063,8 +3063,11 @@ LogicalResult WinogradInputTransformOp::verify() {
   int m = getM();
   int r = getR();
   int64_t tileSize = m + r - 1;
-  bool leftTransform = inputH != 1;
-  bool rightTransform = inputW != 1;
+
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
+  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
 
   SmallVector<int64_t> expectedOutputShape(6, inputH);
   if (ShapedType::isDynamic(inputH)) {
@@ -3073,7 +3076,7 @@ LogicalResult WinogradInputTransformOp::verify() {
   } else {
     expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
     expectedOutputShape[getOutputTileHDim()] =
-        leftTransform ? (inputH - (r - 1)) / m : 1;
+        leftTransform ? (inputH - (r - 1)) / m : inputH;
   }
   if (ShapedType::isDynamic(inputW)) {
     expectedOutputShape[getOutputAlphaWDim()] = tileSize;
@@ -3081,13 +3084,11 @@ LogicalResult WinogradInputTransformOp::verify() {
   } else {
     expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
     expectedOutputShape[getOutputTileWDim()] =
-        rightTransform ? (inputW - (r - 1)) / m : 1;
+        rightTransform ? (inputW - (r - 1)) / m : inputW;
   }
   expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
   expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
 
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  ArrayRef<int64_t> outputShape = outputType.getShape();
   if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
     return emitOpError("the output shape is not expected");
   }
@@ -3124,15 +3125,17 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  ShapedType inputType = getInputOperandType();
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputH = inputShape[getInputHDim()];
-  int64_t inputW = inputShape[getInputWDim()];
+  ShapedType outputType = getOutputOperandType();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
+  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
+
   int64_t m = getM();
   int64_t r = getR();
   int64_t alpha = m + r - 1;
-  int64_t alphaH = inputH != 1 ? alpha : 1;
-  int64_t alphaW = inputW != 1 ? alpha : 1;
+  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
+  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
+
   IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
   IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
@@ -3165,6 +3168,11 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   int64_t m = getM();
   int64_t r = getR();
 
+  ShapedType outputType = getOutputOperandType();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t alphaH = outputShape[getOutputAlphaHDim()];
+  int64_t alphaW = outputShape[getOutputAlphaWDim()];
+
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
   auto offsetAffineMap =
@@ -3190,9 +3198,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   sliceOffsets.append(
       {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
   OpFoldResult sizeH =
-      inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
+      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
   OpFoldResult sizeW =
-      inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
+      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
   sliceSizes.append(
       {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
   int64_t inputRank = getInputOperandRank();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 79f77822116fd7..f1059ddf5da2cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -514,12 +514,14 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
     Value CIter = ivs[3];
 
     auto context = builder.getContext();
+
+    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
     auto affineMap =
         AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-    Value heightOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
-    Value widthOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+    Value heightOffset = builder.create<affine::AffineApplyOp>(
+        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+    Value widthOffset = builder.create<affine::AffineApplyOp>(
+        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
 
     // Extract (H, W) from (N, H, W, C).
     auto extractInput =
@@ -753,12 +755,13 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
     Value zero = builder.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
 
+    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
     auto affineMap =
         AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-    Value heightOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
-    Value widthOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+    Value heightOffset = builder.create<affine::AffineApplyOp>(
+        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+    Value widthOffset = builder.create<affine::AffineApplyOp>(
+        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
 
     Value outInitVal =
         extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
@@ -1075,16 +1078,17 @@ FailureOr<Operation *>
 decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradInputTransformOp op) {
   Location loc = op.getLoc();
-  Value input = op.getInput();
-  auto inputType = cast<ShapedType>(input.getType());
-  auto inputShape = inputType.getShape();
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
+  Value output = op.getOutput();
+  auto outputType = cast<ShapedType>(output.getType());
+  auto outputShape = outputType.getShape();
+
+  int64_t outputH = outputShape[0];
+  int64_t outputW = outputShape[1];
 
   // For F(m x 1, r x 1), we only need to do left side transform.
-  bool leftTransform = inputH != 1;
+  bool leftTransform = outputH != 1;
   // For F(1 x m, 1 x r), we only need to do right side transform.
-  bool rightTransform = inputW != 1;
+  bool rightTransform = outputW != 1;
   Value transformedInput =
       inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 776dc5b748c846..a9af874ea40933 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -61,13 +61,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
 // CHECK:        %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
 // CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
@@ -195,13 +194,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
 // CHECK:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
 // CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
 // CHECK:        %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
 // CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
@@ -346,3 +344,106 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
 // CHECK:     scf.yield %[[S7]]
 // CHECK:   return %[[S6]]
+
+// -----
+
+func.func @conv2d_mx1_rx1_2(%arg0: tensor<2x6x2x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<6x1x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+  %2 = tensor.empty() : tensor<6x1x1x2x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x2x2x5xf32> into tensor<6x4x5xf32>
+  %4 = tensor.empty() : tensor<6x4x2xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2] : tensor<6x4x2xf32> into tensor<6x1x1x2x2x2xf32>
+  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
+  return %7 : tensor<2x4x2x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x2x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
+// CHECK:   %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK:  %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK:  %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK:  %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[CST_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK:       %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[S10:.*]] = linalg.matmul ins(%[[CST_2]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE]]
+// CHECK:     scf.yield %[[S7]]
+// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x2x2x5xf32>
+// CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
+// CHECK:     %[[S8:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %8, 0] [2, 6, 1, 5] [1, 1, 1, 1]
+// CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:     %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
+// CHECK:     %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
+// CHECK:       %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S11:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK:       %[[S12:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S11]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[S13:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_6]] : tensor<6x6xf32>, tensor<6x1xf32>) outs(%[[S12]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[INSERTED_SLICE_7:.*]] = tensor.insert_slice %[[S13]] into %[[ARG8]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE_7]]
+// CHECK:     scf.yield %[[S10]]
+// CHECK:    %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:    scf.yield %[[INSERTED_SLICE]]
+// CHECK:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK:   %[[COLLAPSED_4:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:   %[[S4:.*]] = tensor.empty() : tensor<6x4x2xf32>
+// CHECK:   %[[S5:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S4]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+// CHECK:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_4]], %[[COLLAPSED]] : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2]
+// CHECK:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, 0] [2, 4, 1, 2] [1, 1, 1, 1]
+// CHECK:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
+// CHECK:       %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
+// CHECK:       %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG8]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK:       %[[S11:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK:       %[[S12:.*]] = linalg.matmul ins(%[[CST_0]], %[[EXTRACTED_SLICE_6]] : tensor<4x6xf32>, tensor<6x1xf32>) outs(%[[S11]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK:       %[[S13:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S12]] : f32, tensor<4x1xf32>) outs(%[[EXTRACTED_SLICE_7]] : tensor<4x1xf32>) {
+// CHECK:       ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:          %[[VAL_57:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
+// CHECK:          %[[VAL_58:.*]] = arith.addf %[[VAL_57]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[VAL_58]] : f32
+// CHECK:       } -> tensor<4x1xf32>
+// CHECK:       %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[S13]] into %[[ARG8]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE_8]]
+// CHECK:     scf.yield %[[S9]]
+// CHECK:     %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S8]] into %[[ARG4]][0, 0, 0, 0] [2, 4, 1, 2] [1, 1, 1, 1]
+// CHECK:     scf.yield %[[INSERTED_SLICE]]
+// CHECK:   return %[[S7]]
diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
index 9598c434aadb8f..336b69160dbe87 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -21,11 +21,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[C1_1:.*]] = arith.constant 1 : index
-// CHECK:  %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:    %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
+// CHECK:  %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[ARG1]])
+// CHECK:    %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
-// CHECK:      %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
+// CHECK:      %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
 // CHECK:      %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S3]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
 
 // -----
 
@@ -51,14 +52,14 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C2_1:.*]] = arith.constant 2 : index
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C2_1]]
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[ARG1]])
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C2_1]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
 // CHECK:       %[[C5_2:.*]] = arith.constant 5 : index
 // CHECK:       %[[S3:.*]] = affine.min #[[$MAP0]](%[[ARG4]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, %[[S3]]] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x?xf32>
-// CHECK:       %[[EXTRACTED_SLICE_3:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32>
+// CHECK:       %[[EXTRACTED_SLICE_3:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32>
 // CHECK:       %[[S4:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
-
+// CHECK:      %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S4]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x?x1xf32> into tensor<6x6x5x2xf32>
 // -----
 
 func.func @tile_winograd_filter(%arg0: tensor<2x3x1x5xf32>, %arg1: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
@@ -82,11 +83,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C1_1:.*]] = arith.constant 1 : index
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[ARG1]])
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
 // CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
-// CHECK:       %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
+// CHECK:       %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
 // CHECK:       %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S3]] into %[[ARG5]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
 
 // -----
 
@@ -113,15 +115,16 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
-// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[ARG1]])
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
 // CHECK:   %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
 // CHECK:   %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
 // CHECK:   %[[S5:.*]] = affine.apply #[[$MAP1]]()
 // CHECK:   %[[S6:.*]] = affine.apply #[[$MAP1]]()
 // CHECK:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x5xf32>
-// CHECK:   %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK:   %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG5]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
 // CHECK:   %[[S7:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK:   %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S7]] into %[[ARG5]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
 
 // -----
 
@@ -154,17 +157,18 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C1_5:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C1_7:.*]] = arith.constant 1 : index
-// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
-// CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
-// CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
-// CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[ARG1]])
+// CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
+// CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]] iter_args(%[[ARG9:.*]] = %[[ARG7]])
 // CHECK:         %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
 // CHECK:         %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
 // CHECK:         %[[S7:.*]] = affine.apply #[[$MAP1]]()
 // CHECK:         %[[S8:.*]] = affine.apply #[[$MAP1]]()
 // CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], %[[S8]], 1] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<1x?x?x1xf32>
-// CHECK:         %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32>
+// CHECK:         %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32>
 // CHECK:         %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x2x2x2x5xf32>
 
 // -----
 
@@ -198,18 +202,19 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C2_6:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C2_8:.*]] = arith.constant 2 : index
-// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
-// CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
-// CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C2_5]] step %[[C2_6]]
-// CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_7]] to %[[C5]] step %[[C2_8]]
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]] iter_args(%[[ARG3:.*]] = %[[ARG1]])
+// CHECK:   %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
+// CHECK:     %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C2_5]] step %[[C2_6]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:       %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_7]] to %[[C5]] step %[[C2_8]] iter_args(%[[ARG9:.*]] = %[[ARG7]])
 // CHECK:         %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG8]])
 // CHECK:         %[[S6:.*]] = affine.apply #[[$MAP1]](%[[ARG2]])
 // CHECK:         %[[S7:.*]] = affine.apply #[[$MAP1]](%[[ARG4]])
 // CHECK:         %[[S8:.*]] = affine.apply #[[$MAP2]]()
 // CHECK:         %[[S9:.*]] = affine.apply #[[$MAP2]]()
 // CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S6]], %[[S7]], %[[ARG8]]] [2, %[[S8]], %[[S9]], %[[S5]]] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x?xf32>
-// CHECK:         %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32>
+// CHECK:         %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32>
 // CHECK:         %[[S10:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x?xf32> into tensor<6x6x2x2x2x5xf32>
 
 // -----
 
@@ -242,17 +247,18 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C1_2:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C1_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C1_7:.*]] = arith.constant 1 : index
-// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C1]] step %[[C1_0]]
-// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_2]]
-// CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
-// CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C1]] step %[[C1_0]] iter_args(%[[ARG3:.*]] = %[[ARG1]])
+// CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_2]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
+// CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]] iter_args(%[[ARG9:.*]] = %[[ARG7]])
 // CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
 // CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
 // CHECK:           %[[S7:.*]] = affine.apply #[[$MAP1]]()
 // CHECK:           %[[S8:.*]] = affine.apply #[[$MAP1]]()
 // CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], 0, %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
-// CHECK:           %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
+// CHECK:           %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
 // CHECK:           %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x1x1x1xf32> into tensor<1x6x1x2x2x5xf32>
 
 // -----
 

>From 0a0fa156244d15ba3360ec207d79c64bf51dc6b7 Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Wed, 22 Jan 2025 12:14:36 +0000
Subject: [PATCH 2/2] Fix tile offsets for unit H/W dims in Winograd tiling

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 42 +++++++++----------
 .../transform-tile-and-winograd-rewrite.mlir  | 16 ++++---
 .../Linalg/transform-tile-winograd.mlir       | 30 ++++++-------
 3 files changed, 43 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0649fbc8e9549b..d408cc4b16daf6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3160,11 +3160,6 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
-  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  ShapedType inputType = getInputOperandType();
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputH = inputShape[getInputHDim()];
-  int64_t inputW = inputShape[getInputWDim()];
   int64_t m = getM();
   int64_t r = getR();
 
@@ -3175,12 +3170,16 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
 
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
+  auto identityAffineMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
   auto offsetAffineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
   Value mappedOffsetH = affine::makeComposedAffineApply(
-      builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
+      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
+      offsets[getOutputTileHDim()]);
   Value mappedOffsetW = affine::makeComposedAffineApply(
-      builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
+      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
+      offsets[getOutputTileWDim()]);
   auto sizeAffineMap = AffineMap::get(
       1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
   Value mappedSizeH = affine::makeComposedAffineApply(
@@ -3191,10 +3190,8 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   SmallVector<Value> tiledOperands;
   SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
-  OpFoldResult offsetH =
-      inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
-  OpFoldResult offsetW =
-      inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
+  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
+  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
   sliceOffsets.append(
       {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
   OpFoldResult sizeH =
@@ -3308,28 +3305,29 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
 
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
+  auto identityAffineMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
   auto affineMap =
       AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
 
+  ShapedType valueType = getValueOperandType();
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
   Value mappedOffsetH = affine::makeComposedAffineApply(
-      builder, loc, affineMap, offsets[getValueTileHDim()]);
+      builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
+      offsets[getValueTileHDim()]);
   Value mappedOffsetW = affine::makeComposedAffineApply(
-      builder, loc, affineMap, offsets[getValueTileWDim()]);
+      builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
+      offsets[getValueTileWDim()]);
   Value mappedSizeH = affine::makeComposedAffineApply(
       builder, loc, affineMap, sizes[getValueTileHDim()]);
   Value mappedSizeW = affine::makeComposedAffineApply(
       builder, loc, affineMap, sizes[getValueTileWDim()]);
 
-  ShapedType valueType = getValueOperandType();
-  ArrayRef<int64_t> valueShape = valueType.getShape();
-  int64_t valueH = valueShape[0];
-  int64_t valueW = valueShape[1];
   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
-  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  OpFoldResult offsetH =
-      valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
-  OpFoldResult offsetW =
-      valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
+  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
+  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
   OpFoldResult sizeH =
       valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
   OpFoldResult sizeW =
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index a9af874ea40933..cdc4b8a72a276d 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -379,10 +379,9 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_mx1_rx1_2
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x2x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
 // CHECK:   %[[CST:.*]] = arith.constant 3.200000e+01 : f32
 // CHECK:  %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
@@ -405,8 +404,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK:     scf.yield %[[S7]]
 // CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x2x2x5xf32>
 // CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
-// CHECK:     %[[S8:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %8, 0] [2, 6, 1, 5] [1, 1, 1, 1]
+// CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG3]], 0] [2, 6, 1, 5] [1, 1, 1, 1]
 // CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:     %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
 // CHECK:     %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
@@ -427,7 +425,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2]
 // CHECK:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
 // CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
-// CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, 0] [2, 4, 1, 2] [1, 1, 1, 1]
+// CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, %[[ARG3]], 0] [2, 4, 1, 2] [1, 1, 1, 1]
 // CHECK:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
 // CHECK:       %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
 // CHECK:       %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
@@ -435,7 +433,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
 // CHECK:       %[[S11:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
 // CHECK:       %[[S12:.*]] = linalg.matmul ins(%[[CST_0]], %[[EXTRACTED_SLICE_6]] : tensor<4x6xf32>, tensor<6x1xf32>) outs(%[[S11]] : tensor<4x1xf32>) -> tensor<4x1xf32>
-// CHECK:       %[[S13:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S12]] : f32, tensor<4x1xf32>) outs(%[[EXTRACTED_SLICE_7]] : tensor<4x1xf32>) {
+// CHECK:       %[[S13:.*]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S12]] : f32, tensor<4x1xf32>) outs(%[[EXTRACTED_SLICE_7]] : tensor<4x1xf32>) {
 // CHECK:       ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK:          %[[VAL_57:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
 // CHECK:          %[[VAL_58:.*]] = arith.addf %[[VAL_57]], %[[OUT]] : f32
@@ -444,6 +442,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[S13]] into %[[ARG8]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 4, 1, 1] [1, 1, 1, 1]
 // CHECK:       scf.yield %[[INSERTED_SLICE_8]]
 // CHECK:     scf.yield %[[S9]]
-// CHECK:     %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S8]] into %[[ARG4]][0, 0, 0, 0] [2, 4, 1, 2] [1, 1, 1, 1]
+// CHECK:     %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S8]] into %[[ARG4]][0, 0, %[[ARG3]], 0] [2, 4, 1, 2] [1, 1, 1, 1]
 // CHECK:     scf.yield %[[INSERTED_SLICE]]
 // CHECK:   return %[[S7]]
diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
index 336b69160dbe87..fc6424fd4c8125 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -231,8 +231,9 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
-// CHECK: #[[$MAP1:.+]] = affine_map<() -> (6)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<() -> (6)>
 // CHECK-LABEL: func.func @tile_winograd_input(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<2x1x10x5xf32>, %[[ARG1:.*]]: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
@@ -251,11 +252,11 @@ module attributes {transform.with_named_sequence} {
 // CHECK:     %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_2]] iter_args(%[[ARG5:.*]] = %[[ARG3]])
 // CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
 // CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]] iter_args(%[[ARG9:.*]] = %[[ARG7]])
-// CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
-// CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:           %[[S7:.*]] = affine.apply #[[$MAP1]]()
-// CHECK:           %[[S8:.*]] = affine.apply #[[$MAP1]]()
-// CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], 0, %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
+// CHECK:           %[[S5:.*]] = affine.apply #[[$MAP]](%[[ARG2]])
+// CHECK:           %[[S6:.*]] = affine.apply #[[$MAP1]](%[[ARG4]])
+// CHECK:           %[[S7:.*]] = affine.apply #[[$MAP2]]()
+// CHECK:           %[[S8:.*]] = affine.apply #[[$MAP2]]()
+// CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
 // CHECK:           %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
 // CHECK:           %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
 // CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG9]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x1x1x1xf32> into tensor<1x6x1x2x2x5xf32>
@@ -357,8 +358,9 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
-// CHECK: #[[$MAP1:.+]] = affine_map<() -> (4)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<() -> (4)>
 // CHECK-LABEL: func.func @tile_winograd_output(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<6x1x2x1x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
@@ -378,9 +380,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]] iter_args(%[[ARG11:.*]] = %[[ARG10]]) -> (tensor<3x8x1x5xf32>)
 // CHECK:         %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]] iter_args(%[[ARG12:.*]] = %[[ARG11]]) -> (tensor<3x8x1x5xf32>)
 // CHECK:           %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32>
-// CHECK:           %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
-// CHECK:           %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK:           %[[S7:.*]] = affine.apply #[[$MAP1]]()
-// CHECK:           %[[S8:.*]] = affine.apply #[[$MAP1]]()
-// CHECK:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
+// CHECK:           %[[S5:.*]] = affine.apply #[[$MAP]](%[[ARG2]])
+// CHECK:           %[[S6:.*]] = affine.apply #[[$MAP1]](%[[ARG4]])
+// CHECK:           %[[S7:.*]] = affine.apply #[[$MAP2]]()
+// CHECK:           %[[S8:.*]] = affine.apply #[[$MAP2]]()
+// CHECK:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
 // CHECK:           %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>



More information about the Mlir-commits mailing list