[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)

Hsiangkai Wang via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jun 20 05:48:04 PDT 2024


https://github.com/Hsiangkai created https://github.com/llvm/llvm-project/pull/96184

To support conv2d inputs of arbitrary size, implement TilingInterface for the Winograd operators. Before converting the Winograd operators into nested loops of matrix multiplications, first tile the conv2d input into the supported sizes.

Add a transform operator, structured.decompose_winograd_op, to decompose the Winograd operators. Before applying this transform op, use tile_using_for to tile the input data into supported sizes. The test case shows how to tile and decompose the Winograd operators.

>From 7300578082fb321a0617ed2b61202eca39989e59 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:44:27 +0100
Subject: [PATCH] [mlir][linalg] Implement TilingInterface for winograd
 operators

To support conv2d inputs of arbitrary size, implement TilingInterface
for the Winograd operators. Before converting the Winograd operators
into nested loops of matrix multiplications, first tile the conv2d
input into the supported sizes.

Add a transform operator, structured.decompose_winograd_op, to
decompose the Winograd operators. Before applying this transform op,
use tile_using_for to tile the input data into supported sizes. The
test case shows how to tile and decompose the Winograd operators.
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  21 +-
 .../Linalg/TransformOps/LinalgTransformOps.td |  37 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |  45 +++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 281 +++++++++++++++
 .../TransformOps/LinalgTransformOps.cpp       |  27 ++
 .../Linalg/Transforms/WinogradConv2D.cpp      |  18 +
 .../transform-tile-and-winograd-rewrite.mlir  | 332 ++++++++++++++++++
 7 files changed, 758 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index de1097b6ac27b..45726d6ee2224 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,7 +154,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
+    [DeclareOpInterfaceMethods<TilingInterface,
+     ["getIterationDomain",
+      "getLoopIteratorTypes",
+      "getResultTilePosition",
+      "getTiledImplementation"]>]> {
   let summary = "Winograd filter transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -192,7 +197,12 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
+    [DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Winograd input transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -230,7 +240,12 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
   let hasVerifier = 1;
 }
 
-def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
+    [DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
   let summary = "Winograd output transform operator";
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 68d0f713caad4..71736eae38b4f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2638,4 +2638,45 @@ def WinogradConv2DOp : Op<Transform_Dialect,
   }];
 }
 
+def DecomposeWinogradOp : Op<Transform_Dialect,
+    "structured.decompose_winograd_op",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Decompose Winograd transform operators";
+  let description = [{
+    Decompose Winograd operators. The filter, input and output transform
+    operators are converted into an equivalent combination of scf, tensor
+    and linalg operators. Before applying this transform operator, users
+    need to tile the Winograd transform operators into supported sizes,
+    e.g. with `transform.structured.tile_using_for`.
+
+    #### Return modes:
+
+    This operation fails if `target` is not a `linalg.winograd_filter_transform`,
+    `linalg.winograd_input_transform` or `linalg.winograd_output_transform`
+    operation, or if its decomposition fails. Otherwise, the operation
+    succeeds and returns a handle to the operation produced by the
+    decomposition.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bb7ec590faad0..d0eec2be1f8fb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1319,6 +1319,51 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                       int64_t r);
 
+/// Rewrite linalg.winograd_filter_transform. The filter layout is FHWC and
+/// the transformation matrices are 2-dimensional, so each H x W slice is
+/// extracted from FHWC first. Two levels of loops iterate over F and C. The
+/// op should be tiled to supported sizes first. After rewriting:
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradFilterTransformOp op);
+
+/// Rewrite linalg.winograd_input_transform. The input layout is NHWC and
+/// the transformation matrices are 2-dimensional, so each H x W slice is
+/// extracted from NHWC first. Two levels of loops iterate over N and C. The
+/// op should be tiled to supported sizes first. After rewriting:
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract input<h x w> from input<n x h x w x c>
+///     %ret = linalg.matmul BT, %extracted
+///     %ret = linalg.matmul %ret, B
+///     %inserted = insert %ret into input<h x w x n x c>
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+                                  linalg::WinogradInputTransformOp op);
+
+/// Rewrite linalg.winograd_output_transform. The value operand is laid out
+/// as HWNF and the transformation matrices are 2-dimensional, so each H x W
+/// slice is extracted from HWNF first. Two levels of loops iterate over N
+/// and F. The op should be tiled to supported sizes first. After rewriting:
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %f = lo_f to hi_f step 1
+///     %extracted = extract value<h x w> from value<h x w x n x f>
+///     %ret = linalg.matmul AT, %extracted
+///     %ret = linalg.matmul %ret, A
+///     %inserted = insert %ret into output<n x h x w x f>
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradOutputTransformOp op);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7bf2a5bca037f..a416e1f6e257f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2760,6 +2760,91 @@ LogicalResult WinogradFilterTransformOp::verify() {
   return success();
 }
 
+/// The iteration domain is the shape of the output operand
+/// (tileH x tileW x alphaH x alphaW x C x F); every loop starts at 0 with
+/// unit stride.
+SmallVector<Range>
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value output = getOutput();
+  SmallVector<Range> loopBounds(6);
+  for (unsigned dim = 0; dim < 6; ++dim) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = one;
+  }
+  return loopBounds;
+}
+
+/// All six loops of the filter transform are parallel.
+SmallVector<utils::IteratorType>
+WinogradFilterTransformOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(6,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+/// Materializes `opFoldResult` as an SSA value of index type. Integer
+/// attributes become `arith.constant` index ops whatever their integer type,
+/// so the result can always feed index-only users such as `affine.apply`.
+static Value getValueFromOpFoldResult(OpFoldResult opFoldResult,
+                                      OpBuilder &builder, Location loc) {
+  if (auto val = dyn_cast<Value>(opFoldResult))
+    return val;
+  auto intAttr = cast<IntegerAttr>(cast<Attribute>(opFoldResult));
+  return builder.create<arith::ConstantIndexOp>(loc, intAttr.getInt());
+}
+
+/// The iteration domain is the output shape, so the result tile is exactly
+/// the iteration-space tile.
+LogicalResult WinogradFilterTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  resultOffsets.assign(offsets.begin(), offsets.end());
+  resultSizes.assign(sizes.begin(), sizes.end());
+  return success();
+}
+
+/// Creates a filter transform computing the given tile of the output. The
+/// F x H x W x C filter is sliced along F and C (output dims 5 and 4); the
+/// kernel H and W dimensions are always taken whole.
+FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  Location loc = getLoc();
+  Value filter = getFilter();
+  SmallVector<Value> tiledOperands;
+
+  SmallVector<OpFoldResult> filterOffsets = {offsets[5], zeroAttr, zeroAttr,
+                                             offsets[4]};
+  SmallVector<OpFoldResult> filterSizes = {
+      sizes[5], getDimValue(builder, loc, filter, 1),
+      getDimValue(builder, loc, filter, 2), sizes[4]};
+  SmallVector<OpFoldResult> filterStrides(4, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, filter, filterOffsets, filterSizes, filterStrides));
+
+  SmallVector<OpFoldResult> outputOffsets, outputSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   outputOffsets, outputSizes)))
+    return failure();
+  SmallVector<OpFoldResult> outputStrides(6, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), outputOffsets, outputSizes, outputStrides));
+
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // WinogradInputTransformOp
 //===----------------------------------------------------------------------===//
@@ -2786,6 +2869,104 @@ LogicalResult WinogradInputTransformOp::verify() {
   return success();
 }
 
+/// The iteration domain is the shape of the output operand
+/// (tileH x tileW x alphaH x alphaW x N x C); every loop starts at 0 with
+/// unit stride.
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value output = getOutput();
+  SmallVector<Range> loopBounds(6);
+  for (unsigned dim = 0; dim < 6; ++dim) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = one;
+  }
+  return loopBounds;
+}
+
+/// All six loops of the input transform are parallel.
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(6,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+/// The iteration domain is the output shape, so the result tile is exactly
+/// the iteration-space tile.
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  resultOffsets.assign(offsets.begin(), offsets.end());
+  resultSizes.assign(sizes.begin(), sizes.end());
+  return success();
+}
+
+/// Creates an input transform computing the given tile of the output. Along
+/// a spatial dimension, `n` tiles starting at tile `t` read the input range
+/// [t * m, t * m + n * m + r - 1), i.e. alpha elements for a single tile. An
+/// input spatial dimension of size 1 is always read as a unit slice.
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes) {
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  Location loc = getLoc();
+  Value input = getInput();
+  ArrayRef<int64_t> inputShape = cast<ShapedType>(input.getType()).getShape();
+  int64_t m = getM();
+  int64_t r = getR();
+
+  AffineExpr d0 = builder.getAffineDimExpr(0);
+  AffineMap offsetMap = AffineMap::get(1, 0, {d0 * m}, builder.getContext());
+  Value offsetH = builder.create<affine::AffineApplyOp>(
+      loc, offsetMap, getValueFromOpFoldResult(offsets[0], builder, loc));
+  Value offsetW = builder.create<affine::AffineApplyOp>(
+      loc, offsetMap, getValueFromOpFoldResult(offsets[1], builder, loc));
+
+  // Fold the sizes so that constant tile sizes keep the sliced input (and
+  // hence the tiled op) statically shaped.
+  AffineExpr sizeExpr = d0 * m + (r - 1);
+  OpFoldResult sizeH =
+      inputShape[1] != 1
+          ? affine::makeComposedFoldedAffineApply(builder, loc, sizeExpr,
+                                                  {sizes[0]})
+          : OpFoldResult(oneAttr);
+  OpFoldResult sizeW =
+      inputShape[2] != 1
+          ? affine::makeComposedFoldedAffineApply(builder, loc, sizeExpr,
+                                                  {sizes[1]})
+          : OpFoldResult(oneAttr);
+
+  // The input is N x H x W x C; output dims 4 and 5 select N and C.
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> inputOffsets = {offsets[4], offsetH, offsetW,
+                                            offsets[5]};
+  SmallVector<OpFoldResult> inputSizes = {sizes[4], sizeH, sizeW, sizes[5]};
+  SmallVector<OpFoldResult> inputStrides(4, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, input, inputOffsets, inputSizes, inputStrides));
+
+  SmallVector<OpFoldResult> outputOffsets, outputSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   outputOffsets, outputSizes)))
+    return failure();
+  SmallVector<OpFoldResult> outputStrides(6, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), outputOffsets, outputSizes, outputStrides));
+
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // WinogradOutputTransformOp
 //===----------------------------------------------------------------------===//
@@ -2812,6 +3001,101 @@ LogicalResult WinogradOutputTransformOp::verify() {
   return success();
 }
 
+/// The iteration domain is the shape of the value operand
+/// (tileH x tileW x alphaH x alphaW x N x F); every loop starts at 0 with
+/// unit stride.
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value value = getValue();
+  SmallVector<Range> loopBounds(6);
+  for (unsigned dim = 0; dim < 6; ++dim) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
+    loopBounds[dim].stride = one;
+  }
+  return loopBounds;
+}
+
+/// All six loops of the output transform are parallel.
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(6,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+/// Maps an iteration-space tile to the N x H x W x F output. N and F come
+/// from value dims 4 and 5; along a spatial dimension, `n` tiles starting at
+/// tile `t` produce output range [t * m, (t + n) * m). A spatial dimension
+/// whose alpha extent in the value is 1 is produced as a unit slice.
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  int64_t m = getM();
+  Location loc = getLoc();
+  ArrayRef<int64_t> valueShape =
+      cast<ShapedType>(getValue().getType()).getShape();
+
+  AffineExpr d0 = builder.getAffineDimExpr(0);
+  AffineMap offsetMap = AffineMap::get(1, 0, {d0 * m}, builder.getContext());
+  Value offsetH = builder.create<affine::AffineApplyOp>(
+      loc, offsetMap, getValueFromOpFoldResult(offsets[0], builder, loc));
+  Value offsetW = builder.create<affine::AffineApplyOp>(
+      loc, offsetMap, getValueFromOpFoldResult(offsets[1], builder, loc));
+
+  // Fold the sizes so that constant tile sizes keep the result statically
+  // shaped.
+  OpFoldResult sizeH =
+      valueShape[2] != 1
+          ? affine::makeComposedFoldedAffineApply(builder, loc, d0 * m,
+                                                  {sizes[0]})
+          : OpFoldResult(oneAttr);
+  OpFoldResult sizeW =
+      valueShape[3] != 1
+          ? affine::makeComposedFoldedAffineApply(builder, loc, d0 * m,
+                                                  {sizes[1]})
+          : OpFoldResult(oneAttr);
+
+  resultOffsets.append({offsets[4], offsetH, offsetW, offsets[5]});
+  resultSizes.append({sizes[4], sizeH, sizeW, sizes[5]});
+  return success();
+}
+
+/// Creates an output transform computing the given tile: the value operand
+/// is sliced by the iteration-space tile itself and the output by the tile
+/// returned from getResultTilePosition.
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  Location loc = getLoc();
+  SmallVector<Value> tiledOperands;
+
+  SmallVector<OpFoldResult> valueStrides(6, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getValue(), offsets, sizes, valueStrides));
+
+  SmallVector<OpFoldResult> outputOffsets, outputSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   outputOffsets, outputSizes)))
+    return failure();
+  SmallVector<OpFoldResult> outputStrides(4, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), outputOffsets, outputSizes, outputStrides));
+
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d051b29e1f06f..358c15f145407 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3505,6 +3505,32 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Start inserting new IR right before the target op.
+  rewriter.setInsertionPoint(target);
+
+  // Dispatch on the concrete Winograd transform op; anything else is
+  // reported to the rewriter as unsupported.
+  FailureOr<Operation *> decomposed = failure();
+  if (auto filterOp = dyn_cast<linalg::WinogradFilterTransformOp>(target))
+    decomposed = decomposeWinogradFilterTransformOp(rewriter, filterOp);
+  else if (auto inputOp = dyn_cast<linalg::WinogradInputTransformOp>(target))
+    decomposed = decomposeWinogradInputTransformOp(rewriter, inputOp);
+  else if (auto outputOp = dyn_cast<linalg::WinogradOutputTransformOp>(target))
+    decomposed = decomposeWinogradOutputTransformOp(rewriter, outputOp);
+  else
+    decomposed = rewriter.notifyMatchFailure(target, "not supported");
+
+  if (failed(decomposed))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*decomposed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index d245723c85646..7cbd8ed9d44e8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -1083,6 +1083,30 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
   return winogradConv2DHelper(rewriter, op, m, r);
 }
 
+/// Decomposes a linalg.winograd_filter_transform op by delegating to
+/// decomposeWinogradFilterTransformHelper.
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradFilterTransformOp op) {
+  return decomposeWinogradFilterTransformHelper(rewriter, op);
+}
+
+/// Decomposes a linalg.winograd_input_transform op by delegating to
+/// decomposeWinogradInputTransformHelper.
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+                                  linalg::WinogradInputTransformOp op) {
+  return decomposeWinogradInputTransformHelper(rewriter, op);
+}
+
+/// Decomposes a linalg.winograd_output_transform op by delegating to
+/// decomposeWinogradOutputTransformHelper.
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradOutputTransformOp op) {
+  return decomposeWinogradOutputTransformHelper(rewriter, op);
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
new file mode 100644
index 0000000000000..39aeea1770101
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -0,0 +1,332 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+// Aligned case: a 2x10x10x5 input and a 2x3x3x5 filter give a 2x8x8x2
+// output. With m(4) r(3) each tile reads a 6x6 input window (4 + 3 - 1) and
+// produces a 4x4 output block, so 8x8 is covered by exactly 2x2 tiles and no
+// padding is needed.
+#map = affine_map<(d0, d1, d2, d3) -> (0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  // Broadcast the single element of %arg2 to every output element (#map
+  // always reads index 0); this becomes the init of the output transform.
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  // Transformed tensors: the two leading dims are the 2x2 tile grid and the
+  // next two are the 6x6 transformed tile.
+  %2 = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+  %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
+  %4 = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+  %5 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%4 : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
+  // Element-wise products across all 2*2*6*6 = 144 positions, expressed as a
+  // batched matmul over the collapsed leading dims.
+  %collapsed = tensor.collapse_shape %3 [[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %5 [[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
+  %6 = tensor.empty() : tensor<144x2x2xf32>
+  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%6 : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+  %expanded = tensor.expand_shape %7 [[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
+  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<2x2x6x6x2x2xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %8 : tensor<2x8x8x2xf32>
+}
+
+// Transform script: tile each Winograd op along its two leading dims with
+// size 1, then decompose the single-tile op left inside each loop nest.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    // tile_sizes has one entry per loop of the op's iteration domain (6 here);
+    // a 0 leaves that dimension untiled, so each op gets a 2-deep scf.for nest.
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    // Re-match through the tiled-op handles and decompose each tiled op; the
+    // expected result is the matmul-based loop nests in the checks below.
+    %6 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %1 : (!transform.any_op) -> !transform.any_op
+    %7 = transform.structured.decompose_winograd_op %6 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG:    %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:    %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:    %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:    %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:    %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:    %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:    %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:    %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]]) -> (tensor<2x2x6x6x5x2xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x2x6x6x5x2xf32>) {
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S2]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 5, 2] [1, 1, 1, 1, 1, 1] : tensor<2x2x6x6x5x2xf32> to tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-NEXT:         %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_7]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_8]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S19]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[INSERTED_SLICE_9]] into %[[ARG10]][0, 0, 0, 0, %[[ARG9]], %[[ARG7]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_10]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S18]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 5, 2] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x5x2xf32> into tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[S5:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<2x2x6x6x2x5xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x2x6x6x2x5xf32>) {
+// CHECK-NEXT:       %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
+// CHECK-NEXT:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<2x2x6x6x2x5xf32> to tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:         %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S19]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_11]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S18]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x5xf32> into tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:   %[[S8:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[S9:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S8]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
+// CHECK-NEXT:   %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   %[[S11:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<2x2x6x6x2x2xf32> to tensor<1x1x6x6x2x2xf32>
+// CHECK-NEXT:       %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S1]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:         %[[S16:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x2xf32> to tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S17:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:           %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]], #[[$MAP4]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S20]] : f32, tensor<4x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) {
+// CHECK-NEXT:           ^bb0(%[[IN:.*]]: f32, %[[IN_12:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:             %[[S24:.*]] = arith.mulf %[[IN]], %[[IN_12]] : f32
+// CHECK-NEXT:             linalg.yield %[[S24]] : f32
+// CHECK-NEXT:           } -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_11]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S16]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[S16:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S17:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[S11]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+// Unaligned case: an 11x11 input gives a 9x9 output, which is not a multiple
+// of the 4x4 output tile. The input is padded to 14x14 (3 * 4 + 2) and the
+// output to 12x12 (3 * 4) so that a 3x3 tile grid fits exactly; the result is
+// sliced back to 9x9 at the end.
+#map = affine_map<(d0, d1, d2, d3) -> (0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = tensor.empty() : tensor<2x9x9x2xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x9x9x2xf32>
+  %2 = tensor.empty() : tensor<3x3x6x6x5x2xf32>
+  %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
+  // NOTE(review): the padded border of %4 comes from tensor.empty, so its
+  // contents are undefined and feed the last row/column of input tiles. That
+  // is enough for this structural test; if the IR is meant to be numerically
+  // meaningful, zero padding (e.g. tensor.pad) is likely intended - confirm.
+  %4 = tensor.empty() : tensor<2x14x14x5xf32>
+  %inserted_slice = tensor.insert_slice %arg0 into %4[0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
+  %5 = tensor.empty() : tensor<3x3x6x6x2x5xf32>
+  %6 = linalg.winograd_input_transform m(4) r(3) ins(%inserted_slice : tensor<2x14x14x5xf32>) outs(%5 : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
+  %collapsed = tensor.collapse_shape %3 [[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %6 [[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
+  %7 = tensor.empty() : tensor<324x2x2xf32>
+  %8 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%7 : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
+  %expanded = tensor.expand_shape %8 [[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
+  // The output is padded the same way; the region past 9x9 is dropped by the
+  // final extract_slice.
+  %9 = tensor.empty() : tensor<2x12x12x2xf32>
+  %inserted_slice_1 = tensor.insert_slice %1 into %9[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
+  %10 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<3x3x6x6x2x2xf32>) outs(%inserted_slice_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  %extracted_slice = tensor.extract_slice %10[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+  return %extracted_slice : tensor<2x9x9x2xf32>
+}
+
+// Same transform script as the aligned case. Only the trip counts of the
+// generated tile loops change, because the padded shapes give a 3x3 grid.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    // Tile the two leading (tile-grid) dims by 1; a 0 leaves a dim untiled.
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    // Re-match through the tiled-op handles and decompose each tiled op.
+    %6 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %1 : (!transform.any_op) -> !transform.any_op
+    %7 = transform.structured.decompose_winograd_op %6 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK-DAG:    %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:    %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:    %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:    %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:    %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:    %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:    %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:    %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:    %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK:        %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
+// CHECK-NEXT:   %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
+// CHECK-NEXT:   ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:     linalg.yield %[[IN]] : f32
+// CHECK-NEXT:   } -> tensor<2x9x9x2xf32>
+// CHECK-NEXT:   %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S3:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:   %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]]) -> (tensor<3x3x6x6x5x2xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<3x3x6x6x5x2xf32>) {
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S2]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 5, 2] [1, 1, 1, 1, 1, 1] : tensor<3x3x6x6x5x2xf32> to tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-NEXT:         %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<1x1x6x6x5x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_7]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<6x3xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_8]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S19]] : tensor<6x3xf32>) -> tensor<6x3xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[INSERTED_SLICE_9]] into %[[ARG10]][0, 0, 0, 0, %[[ARG9]], %[[ARG7]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_10]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S18]] : tensor<1x1x6x6x5x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 5, 2] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x5x2xf32> into tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
+// CHECK-NEXT:   %[[INSERTED_INPUT_BUF:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
+// CHECK-NEXT:   %[[S5:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:   %[[S6:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<3x3x6x6x2x5xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<3x3x6x6x2x5xf32>) {
+// CHECK-NEXT:       %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[INSERTED_INPUT_BUF]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32>
+// CHECK-NEXT:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<3x3x6x6x2x5xf32> to tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:         %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<1x1x6x6x2x5xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S19]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<6x6xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_11]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S18]] : tensor<1x1x6x6x2x5xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x5xf32> into tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[COLLAPSED:.*]] = tensor.collapse_shape %4 {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
+// CHECK-NEXT:   %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
+// CHECK-NEXT:   %[[S8:.*]] = tensor.empty() : tensor<324x2x2xf32>
+// CHECK-NEXT:   %[[S9:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S8]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
+// CHECK-NEXT:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
+// CHECK-NEXT:   %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[INSERTED_OUTPUT_BUF:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[S10:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK-NEXT:   %[[S11:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:     %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK-NEXT:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<3x3x6x6x2x2xf32> to tensor<1x1x6x6x2x2xf32>
+// CHECK-NEXT:       %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[INSERTED_OUTPUT_BUF]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32>
+// CHECK-NEXT:       %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:         %[[S16:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x2xf32> to tensor<1x1x6x6x1x1xf32>
+// CHECK-NEXT:           %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> to tensor<6x6xf32>
+// CHECK-NEXT:           %[[S17:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:           %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:           %[[S19:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S20:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S21:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:           %[[S22:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]], #[[$MAP4]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S20]] : f32, tensor<4x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) {
+// CHECK-NEXT:           ^bb0(%[[IN:.*]]: f32, %[[IN_12:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:             %[[S24:.*]] = arith.mulf %[[IN]], %[[IN_12]] : f32
+// CHECK-NEXT:             linalg.yield %[[S24]] : f32
+// CHECK-NEXT:           } -> tensor<4x4xf32>
+// CHECK-NEXT:           %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK-NEXT:           %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:           scf.yield %[[INSERTED_SLICE_11]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %[[S16]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:       }
+// CHECK-NEXT:       %[[S16:.*]] = affine.apply #[[$MAP2]](%[[ARG3]])
+// CHECK-NEXT:       %[[S17:.*]] = affine.apply #[[$MAP2]](%[[ARG5]])
+// CHECK-NEXT:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:       scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %[[S12]] : tensor<2x12x12x2xf32>
+// CHECK-NEXT:   }
+// CHECK-NEXT:   %[[RET:.*]] = tensor.extract_slice %[[S11]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:   return %[[RET]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }



More information about the llvm-branch-commits mailing list