[Mlir-commits] [mlir] c4bf949 - [mlir][linalg] Implement TilingInterface for winograd operators (#96184)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 16 08:22:07 PDT 2024
Author: Hsiangkai Wang
Date: 2024-08-16T16:22:02+01:00
New Revision: c4bf949171a72383d5ba4d2b587d4cc496aacebb
URL: https://github.com/llvm/llvm-project/commit/c4bf949171a72383d5ba4d2b587d4cc496aacebb
DIFF: https://github.com/llvm/llvm-project/commit/c4bf949171a72383d5ba4d2b587d4cc496aacebb.diff
LOG: [mlir][linalg] Implement TilingInterface for winograd operators (#96184)
In order to support conv2d input data of arbitrary size, implement
TilingInterface for the winograd operations. Before converting winograd
operations into nested loops with matrix multiplies, tile the input of
conv2d into the supported size first.
Add a transform operation, structured.decompose_winograd_op, to decompose
winograd operations. Before applying the transform op, use
tile_using_for to tile the input data into the supported size. The test case
shows how to tile and decompose winograd operations.
Added:
mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a9007c8db3078e..5b6a90f806bedd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,8 +154,13 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}
-def Linalg_WinogradFilterTransformOp :
- Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
+ [AllElementTypesMatch<["filter", "output"]>,
+ DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
let summary = "Winograd filter transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -190,11 +195,42 @@ def Linalg_WinogradFilterTransformOp :
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
+  let extraClassDeclaration = [{
+    // Shaped types and ranks of the two tensor operands.
+    ShapedType getFilterOperandType() {
+      return cast<ShapedType>(getFilter().getType());
+    }
+    ShapedType getOutputOperandType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getFilterOperandRank() { return getFilterOperandType().getRank(); }
+    int64_t getOutputOperandRank() { return getOutputOperandType().getRank(); }
+
+    // Dimension positions within the filter operand, laid out as (F, H, W, C).
+    int64_t getFilterFDim() { return 0; }
+    int64_t getFilterHDim() { return 1; }
+    int64_t getFilterWDim() { return 2; }
+    int64_t getFilterCDim() { return 3; }
+  }];
let hasVerifier = 1;
}
-def Linalg_WinogradInputTransformOp :
- Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
+ [AllElementTypesMatch<["input", "output"]>,
+ DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
let summary = "Winograd input transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -229,11 +265,60 @@ def Linalg_WinogradInputTransformOp :
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
+  let extraClassDeclaration = [{
+    // Shaped types and ranks of the two tensor operands.
+    ShapedType getInputOperandType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputOperandType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getInputOperandRank() { return getInputOperandType().getRank(); }
+    int64_t getOutputOperandRank() { return getOutputOperandType().getRank(); }
+
+    // Dimension positions within the input operand, laid out as (N, H, W, C).
+    int64_t getInputNDim() { return 0; }
+    int64_t getInputHDim() { return 1; }
+    int64_t getInputWDim() { return 2; }
+    int64_t getInputCDim() { return 3; }
+
+    // Dimension positions within the output operand, laid out as
+    // (alphaH, alphaW, tileH, tileW, N, C).
+    int64_t getOutputAlphaHDim() { return 0; }
+    int64_t getOutputAlphaWDim() { return 1; }
+    int64_t getOutputTileHDim() { return 2; }
+    int64_t getOutputTileWDim() { return 3; }
+    int64_t getOutputNDim() { return 4; }
+    int64_t getOutputCDim() { return 5; }
+  }];
let hasVerifier = 1;
}
-def Linalg_WinogradOutputTransformOp :
- Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
+ [AllElementTypesMatch<["value", "output"]>,
+ DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
let summary = "Winograd output transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -268,6 +353,50 @@ def Linalg_WinogradOutputTransformOp :
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
+  let extraClassDeclaration = [{
+    // Shaped types and ranks of the two tensor operands.
+    ShapedType getValueOperandType() {
+      return cast<ShapedType>(getValue().getType());
+    }
+    ShapedType getOutputOperandType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getValueOperandRank() { return getValueOperandType().getRank(); }
+    int64_t getOutputOperandRank() { return getOutputOperandType().getRank(); }
+
+    // Dimension positions within the value operand, laid out as
+    // (alphaH, alphaW, tileH, tileW, N, F).
+    int64_t getValueAlphaHDim() { return 0; }
+    int64_t getValueAlphaWDim() { return 1; }
+    int64_t getValueTileHDim() { return 2; }
+    int64_t getValueTileWDim() { return 3; }
+    int64_t getValueNDim() { return 4; }
+    int64_t getValueFDim() { return 5; }
+
+    // Dimension positions within the output operand, laid out as (N, H, W, F).
+    int64_t getOutputNDim() { return 0; }
+    int64_t getOutputHDim() { return 1; }
+    int64_t getOutputWDim() { return 2; }
+    int64_t getOutputFDim() { return 3; }
+  }];
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ecc86999006db6..106f0d79d9792d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2697,4 +2697,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
}];
}
+def DecomposeWinogradOp : Op<Transform_Dialect,
+    "structured.decompose_winograd_op",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Decompose winograd operations. It converts
+    `linalg.winograd_filter_transform`, `linalg.winograd_input_transform` and
+    `linalg.winograd_output_transform` operations into a combination of
+    equivalent scf, tensor, and linalg operations. Before applying this
+    transform operation, users need to tile the winograd transform operations
+    into supported sizes, e.g. with `transform.structured.tile_using_for`.
+
+    #### Return modes:
+
+    This operation produces a silenceable failure if `target` is not one of
+    the winograd transform operations listed above, or if decomposing it
+    fails. Otherwise, the operation succeeds and returns a handle of the
+    sequence that replaces the original operations.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 477ef7bfafb181..861e14d22d9625 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1316,6 +1316,63 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcOp op, int64_t m,
int64_t r);
+/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
+/// FHWC. The transformation matrix is 2-dimensional. We need to extract H x W
+/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
+/// the rewriting, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret<alphaH x alphaW> into
+///                 output<alphaH x alphaW x c x f>
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradFilterTransformOp op);
+
+/// Rewrite linalg.winograd_input_transform. The data layout of the input is
+/// NHWC. The transformation matrix is 2-dimensional. We need to extract
+/// alphaH x alphaW from NHWC first. We generate 4 levels of loops to iterate
+/// on tileH, tileW, N, and C. After the rewriting, we get
+///
+/// scf.for %h = 0 to tileH step 1
+///   scf.for %w = 0 to tileW step 1
+///     scf.for %n = 0 to N step 1
+///       scf.for %c = 0 to C step 1
+///         %extracted = extract %extracted<alphaH x alphaW> from
+///                              %input<N x H x W x C>
+///                              at [%n, (%h x m), (%w x m), %c]
+///         %ret = linalg.matmul BT, %extracted
+///         %ret = linalg.matmul %ret, B
+///         %inserted = insert %ret<alphaH x alphaW> into
+///                            %output<alphaH x alphaW x tileH x tileW x N x C>
+///                            at [0, 0, %h, %w, %n, %c]
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+                                  linalg::WinogradInputTransformOp op);
+
+/// Rewrite linalg.winograd_output_transform. The data layout of the input
+/// value is (alphaH, alphaW, tileH, tileW, N, F) and the data layout of the
+/// output is NHWF. The transformation matrix is 2-dimensional. We need to
+/// extract alphaH x alphaW from the value first. We generate 4 levels of loops
+/// to iterate on tileH, tileW, N, and F. After the transformation, we get
+///
+/// scf.for %h = 0 to tileH step 1
+///   scf.for %w = 0 to tileW step 1
+///     scf.for %n = 0 to N step 1
+///       scf.for %f = 0 to F step 1
+///         %extracted = extract %extracted<alphaH x alphaW> from
+///                              %input<alphaH x alphaW x tileH x tileW x N x F>
+///                              at [0, 0, %h, %w, %n, %f]
+///         %ret = linalg.matmul AT, %extracted
+///         %ret = linalg.matmul %ret, A
+///         %inserted = insert %ret<alphaH x alphaW> into
+///                            output<N x H x W x F>
+///                            at [%n, (%h x m), (%w x m), %f]
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradOutputTransformOp op);
+
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a101552e419bc8..775ed8f37344ed 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2855,8 +2855,8 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
LogicalResult WinogradFilterTransformOp::verify() {
auto filterType = cast<ShapedType>(getFilter().getType());
ArrayRef<int64_t> filterShape = filterType.getShape();
- int64_t filterH = filterShape[1];
- int64_t filterW = filterShape[2];
+ int64_t filterH = filterShape[getFilterHDim()];
+ int64_t filterW = filterShape[getFilterWDim()];
int64_t r = getR();
int64_t m = getM();
@@ -2870,8 +2870,8 @@ LogicalResult WinogradFilterTransformOp::verify() {
SmallVector<int64_t> expectedOutputShape;
expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
- expectedOutputShape.push_back(filterShape[3]);
- expectedOutputShape.push_back(filterShape[0]);
+ expectedOutputShape.push_back(filterShape[getFilterCDim()]);
+ expectedOutputShape.push_back(filterShape[getFilterFDim()]);
auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -2881,6 +2881,103 @@ LogicalResult WinogradFilterTransformOp::verify() {
return success();
}
+SmallVector<Range>
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+  // The iteration space is the full filter shape (F, KH, KW, C): every
+  // dimension starts at 0, spans the whole filter extent and has stride 1.
+  Location loc = getLoc();
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  Value filter = getFilter();
+  int64_t filterRank = getFilterOperandRank();
+  SmallVector<Range> loopBounds(filterRank);
+  // Signed induction variable to match `filterRank` (avoids -Wsign-compare).
+  for (int64_t dim = 0; dim < filterRank; ++dim) {
+    loopBounds[dim].offset = zeroAttr;
+    loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
+    loopBounds[dim].stride = oneAttr;
+  }
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradFilterTransformOp::getLoopIteratorTypes() {
+  // One loop per filter dimension (F, KH, KW, C); none of them is a reduction.
+  return SmallVector<utils::IteratorType>(getFilterOperandRank(),
+                                          utils::IteratorType::parallel);
+}
+
+/// Returns the position of the output tile produced for the iteration-space
+/// tile described by `offsets`/`sizes` (indexed as F, KH, KW, C). The output
+/// layout is (alphaH, alphaW, C, F); the alpha dimensions are never tiled, so
+/// their tile always starts at 0 and covers the whole alpha extent.
+LogicalResult WinogradFilterTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  // Offsets and sizes are index quantities; use index attributes, consistent
+  // with getIterationDomain.
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  ArrayRef<int64_t> filterShape = getFilterOperandType().getShape();
+  int64_t filterH = filterShape[getFilterHDim()];
+  int64_t filterW = filterShape[getFilterWDim()];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  // A kernel dimension of extent 1 gets an alpha extent of 1; any other
+  // kernel dimension gets alpha = m + r - 1.
+  int64_t alphaH = filterH != 1 ? alpha : 1;
+  int64_t alphaW = filterW != 1 ? alpha : 1;
+  IntegerAttr alphaHAttr = builder.getIndexAttr(alphaH);
+  IntegerAttr alphaWAttr = builder.getIndexAttr(alphaW);
+
+  resultOffsets.append(
+      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
+  resultSizes.append(
+      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
+
+  return success();
+}
+
+/// Implement tiling for winograd_filter_transform.
+/// The input of winograd_filter_transform is (F, KH, KW, C).
+/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
+/// Users can specify the tile sizes of F and C.
+/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
+/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
+/// KH and KW are never tiled: every filter slice covers the whole kernel.
+FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  // Offsets, sizes and strides are index quantities.
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  ArrayRef<int64_t> filterShape = getFilterOperandType().getShape();
+  IntegerAttr filterHAttr = builder.getIndexAttr(filterShape[getFilterHDim()]);
+  IntegerAttr filterWAttr = builder.getIndexAttr(filterShape[getFilterWDim()]);
+  Location loc = getLoc();
+  SmallVector<Value> tiledOperands;
+
+  // Filter slice: tile F and C, keep the whole KH x KW window.
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+  sliceOffsets.append(
+      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
+  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
+                     sizes[getFilterCDim()]});
+  SmallVector<OpFoldResult> filterStrides(getFilterOperandRank(), oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
+
+  // Output slice matching the tile of the op's single result (result #0).
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   resultOffsets, resultSizes)))
+    return failure();
+
+  SmallVector<OpFoldResult> outputStrides(getOutputOperandRank(), oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+
+  // The tiled op returns a value typed like its output slice.
+  SmallVector<Type> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//
@@ -2888,8 +2985,8 @@ LogicalResult WinogradFilterTransformOp::verify() {
LogicalResult WinogradInputTransformOp::verify() {
auto inputType = cast<ShapedType>(getInput().getType());
ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputH = inputShape[1];
- int64_t inputW = inputShape[2];
+ int64_t inputH = inputShape[getInputHDim()];
+ int64_t inputW = inputShape[getInputWDim()];
int m = getM();
int r = getR();
int64_t tileSize = m + r - 1;
@@ -2898,21 +2995,23 @@ LogicalResult WinogradInputTransformOp::verify() {
SmallVector<int64_t> expectedOutputShape(6, inputH);
if (ShapedType::isDynamic(inputH)) {
- expectedOutputShape[0] = tileSize;
- expectedOutputShape[2] = ShapedType::kDynamic;
+ expectedOutputShape[getOutputAlphaHDim()] = tileSize;
+ expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
} else {
- expectedOutputShape[0] = leftTransform ? tileSize : 1;
- expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
+ expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
+ expectedOutputShape[getOutputTileHDim()] =
+ leftTransform ? (inputH - (r - 1)) / m : 1;
}
if (ShapedType::isDynamic(inputW)) {
- expectedOutputShape[1] = tileSize;
- expectedOutputShape[3] = ShapedType::kDynamic;
+ expectedOutputShape[getOutputAlphaWDim()] = tileSize;
+ expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
} else {
- expectedOutputShape[1] = rightTransform ? tileSize : 1;
- expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1;
+ expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
+ expectedOutputShape[getOutputTileWDim()] =
+ rightTransform ? (inputW - (r - 1)) / m : 1;
}
- expectedOutputShape[4] = inputShape[0];
- expectedOutputShape[5] = inputShape[3];
+ expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
+ expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -2922,6 +3021,130 @@ LogicalResult WinogradInputTransformOp::verify() {
return success();
}
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  // The iteration space is the full output shape
+  // (alphaH, alphaW, tileH, tileW, N, C), each with offset 0 and stride 1.
+  Location loc = getLoc();
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  Value output = getOutput();
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<Range> loopBounds(outputRank);
+  // Signed induction variable to match `outputRank` (avoids -Wsign-compare).
+  for (int64_t dim = 0; dim < outputRank; ++dim) {
+    loopBounds[dim].offset = zeroAttr;
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = oneAttr;
+  }
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  // One loop per output dimension (alphaH, alphaW, tileH, tileW, N, C); none
+  // of them is a reduction.
+  return SmallVector<utils::IteratorType>(getOutputOperandRank(),
+                                          utils::IteratorType::parallel);
+}
+
+/// Returns the position of the output tile produced for the iteration-space
+/// tile described by `offsets`/`sizes`, which are indexed like the output
+/// operand (alphaH, alphaW, tileH, tileW, N, C). The alpha dimensions are
+/// never tiled, so their tile always starts at 0 and spans the alpha extent.
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  // Offsets and sizes are index quantities; use index attributes, consistent
+  // with getIterationDomain.
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  ArrayRef<int64_t> inputShape = getInputOperandType().getShape();
+  int64_t inputH = inputShape[getInputHDim()];
+  int64_t inputW = inputShape[getInputWDim()];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  // An input dimension of extent 1 gets an alpha extent of 1; any other
+  // input dimension gets alpha = m + r - 1.
+  int64_t alphaH = inputH != 1 ? alpha : 1;
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  IntegerAttr alphaHAttr = builder.getIndexAttr(alphaH);
+  IntegerAttr alphaWAttr = builder.getIndexAttr(alphaW);
+
+  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
+                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
+                        offsets[getOutputCDim()]});
+  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
+                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
+                      sizes[getOutputCDim()]});
+
+  return success();
+}
+
+/// Implement tiling for winograd_input_transform.
+/// The input of winograd_input_transform is (N, H, W, C).
+/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
+/// C). Users can specify the tile sizes of tileH, tileW, N, and C. `offsets`
+/// are the values for the offsets of tileH, tileW, N, C for one tile. `sizes`
+/// are the values for the sizes of tileH, tileW, N, C for one tile.
+/// A run of `t` tiles starting at tile `o` reads input rows
+/// [o * m, o * m + t * m + (r - 1)), so neighbouring tiles overlap by r - 1.
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes) {
+  // Offsets, sizes and strides are index quantities.
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  ShapedType inputType = getInputOperandType();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[getInputHDim()];
+  int64_t inputW = inputShape[getInputWDim()];
+  int64_t m = getM();
+  int64_t r = getR();
+
+  Location loc = getLoc();
+  MLIRContext *context = builder.getContext();
+  // Tile index -> first input row/column covered by that tile.
+  auto offsetAffineMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+  Value mappedOffsetH = affine::makeComposedAffineApply(
+      builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
+  Value mappedOffsetW = affine::makeComposedAffineApply(
+      builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
+  // Number of tiles -> number of input rows/columns they read.
+  auto sizeAffineMap = AffineMap::get(
+      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
+  Value mappedSizeH = affine::makeComposedAffineApply(
+      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
+  Value mappedSizeW = affine::makeComposedAffineApply(
+      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
+
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  // An input dimension of extent 1 is read whole (offset 0, size 1).
+  OpFoldResult offsetH =
+      inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
+  OpFoldResult offsetW =
+      inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
+  sliceOffsets.append(
+      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
+  OpFoldResult sizeH =
+      inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
+  OpFoldResult sizeW =
+      inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
+  sliceSizes.append(
+      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
+  SmallVector<OpFoldResult> inputStrides(getInputOperandRank(), oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+
+  // Output slice matching the tile of the op's single result (result #0).
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   resultOffsets, resultSizes)))
+    return failure();
+
+  SmallVector<OpFoldResult> outputStrides(getOutputOperandRank(), oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, outputStrides));
+
+  // The tiled op returns a value typed like its output slice.
+  SmallVector<Type> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//
@@ -2929,32 +3152,34 @@ LogicalResult WinogradInputTransformOp::verify() {
LogicalResult WinogradOutputTransformOp::verify() {
auto valueType = cast<ShapedType>(getValue().getType());
ArrayRef<int64_t> valueShape = valueType.getShape();
- int64_t valueH = valueShape[0];
- int64_t valueW = valueShape[1];
- int64_t valueTileH = valueShape[2];
- int64_t valueTileW = valueShape[3];
+ int64_t valueH = valueShape[getValueAlphaHDim()];
+ int64_t valueW = valueShape[getValueAlphaWDim()];
+ int64_t valueTileH = valueShape[getValueTileHDim()];
+ int64_t valueTileW = valueShape[getValueTileWDim()];
int m = getM();
int r = getR();
bool leftTransform = valueH != 1;
bool rightTransform = valueW != 1;
- SmallVector<int64_t> expectedOutputShape(4, valueH);
+ int64_t outputRank = getOutputOperandRank();
+ SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
- expectedOutputShape[1] = ShapedType::kDynamic;
+ expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
} else {
if (valueH != (leftTransform ? m + r - 1 : 1))
return emitOpError("expect input height equals to input tile size");
- expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH;
+ expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
}
if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
- expectedOutputShape[2] = ShapedType::kDynamic;
+ expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
} else {
if (valueW != (rightTransform ? m + r - 1 : 1))
return emitOpError("expect input width equals to input tile size");
- expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW;
+ expectedOutputShape[getOutputWDim()] =
+ (rightTransform ? m : 1) * valueTileW;
}
- expectedOutputShape[0] = valueShape[4];
- expectedOutputShape[3] = valueShape[5];
+ expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
+ expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -2964,6 +3189,124 @@ LogicalResult WinogradOutputTransformOp::verify() {
return success();
}
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+  // The iteration space is the full value shape
+  // (alphaH, alphaW, tileH, tileW, N, F), each with offset 0 and stride 1.
+  Location loc = getLoc();
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  Value value = getValue();
+  int64_t valueRank = getValueOperandRank();
+  SmallVector<Range> loopBounds(valueRank);
+  // Signed induction variable to match `valueRank` (avoids -Wsign-compare).
+  for (int64_t dim = 0; dim < valueRank; ++dim) {
+    loopBounds[dim].offset = zeroAttr;
+    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
+    loopBounds[dim].stride = oneAttr;
+  }
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+  // One loop per value dimension (alphaH, alphaW, tileH, tileW, N, F); none
+  // of them is a reduction.
+  return SmallVector<utils::IteratorType>(getValueOperandRank(),
+                                          utils::IteratorType::parallel);
+}
+
+/// Returns the position of the output tile that corresponds to the
+/// iteration-space tile described by `offsets`/`sizes`, which are indexed like
+/// the value operand (alphaH, alphaW, tileH, tileW, N, F). The output layout
+/// is (N, H, W, F).
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  int64_t m = getM();
+
+  Location loc = getLoc();
+  MLIRContext *context = builder.getContext();
+  // A transformed spatial dimension maps tile index `t` to output row `t * m`.
+  auto affineMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+
+  Value mappedOffsetH = affine::makeComposedAffineApply(
+      builder, loc, affineMap, offsets[getValueTileHDim()]);
+  Value mappedOffsetW = affine::makeComposedAffineApply(
+      builder, loc, affineMap, offsets[getValueTileWDim()]);
+  Value mappedSizeH = affine::makeComposedAffineApply(
+      builder, loc, affineMap, sizes[getValueTileHDim()]);
+  Value mappedSizeW = affine::makeComposedAffineApply(
+      builder, loc, affineMap, sizes[getValueTileWDim()]);
+
+  ArrayRef<int64_t> valueShape = getValueOperandType().getShape();
+  // Offsets and sizes are index quantities; use index attributes.
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+
+  // Computes the output offset/size along one spatial dimension. The verifier
+  // sizes an output spatial dimension as `m * tiles` when the alpha extent is
+  // not 1 and as `1 * tiles` otherwise, so an untransformed dimension maps
+  // tiles to output rows one-to-one. With a single tile that is simply
+  // offset 0, size 1.
+  auto spatialTile = [&](int64_t alphaDim, int64_t tileDim, Value mappedOffset,
+                         Value mappedSize, OpFoldResult &offset,
+                         OpFoldResult &size) {
+    if (valueShape[alphaDim] != 1) {
+      offset = OpFoldResult(mappedOffset);
+      size = OpFoldResult(mappedSize);
+    } else if (valueShape[tileDim] == 1) {
+      offset = OpFoldResult(zeroAttr);
+      size = OpFoldResult(oneAttr);
+    } else {
+      offset = offsets[tileDim];
+      size = sizes[tileDim];
+    }
+  };
+  OpFoldResult offsetH, sizeH, offsetW, sizeW;
+  spatialTile(getValueAlphaHDim(), getValueTileHDim(), mappedOffsetH,
+              mappedSizeH, offsetH, sizeH);
+  spatialTile(getValueAlphaWDim(), getValueTileWDim(), mappedOffsetW,
+              mappedSizeW, offsetW, sizeW);
+
+  resultOffsets.append(
+      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
+  resultSizes.append(
+      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
+  return success();
+}
+
+/// Implement tiling for winograd_output_transform.
+/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
+/// F). The output of winograd_output_transform is (N, H, W, F). Users can
+/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
+/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
+/// for the sizes of tileH, tileW, N, F for one tile.
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  // Offsets, sizes and strides are index quantities.
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  Location loc = getLoc();
+  Value value = getValue();
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  // The alpha dimensions are never tiled, so the value slice spans them
+  // entirely. Query the extents through getDimValue so that a dynamic alpha
+  // extent (the verifier accepts one) is not encoded as the negative
+  // ShapedType::kDynamic sentinel in a static slice size.
+  OpFoldResult alphaH = getDimValue(builder, loc, value, getValueAlphaHDim());
+  OpFoldResult alphaW = getDimValue(builder, loc, value, getValueAlphaWDim());
+
+  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
+                       offsets[getValueTileWDim()], offsets[getValueNDim()],
+                       offsets[getValueFDim()]});
+  sliceSizes.append({alphaH, alphaW, sizes[getValueTileHDim()],
+                     sizes[getValueTileWDim()], sizes[getValueNDim()],
+                     sizes[getValueFDim()]});
+  SmallVector<OpFoldResult> sliceStrides(getValueOperandRank(), oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, value, sliceOffsets, sliceSizes, sliceStrides));
+
+  // Output slice matching the tile of the op's single result (result #0).
+  SmallVector<OpFoldResult> resultOffsets, resultSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   resultOffsets, resultSizes)))
+    return failure();
+
+  SmallVector<OpFoldResult> strides(getOutputOperandRank(), oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), resultOffsets, resultSizes, strides));
+
+  // The tiled op returns a value typed like its output slice.
+  SmallVector<Type> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 48b3abbeee7010..fbf4e29024f7c2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3851,6 +3851,47 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+/// Decomposes a single Winograd transform op by dispatching on its concrete
+/// type to decomposeWinogradFilterTransformOp,
+/// decomposeWinogradInputTransformOp or decomposeWinogradOutputTransformOp.
+/// The rewriter's insertion point is placed just before `target` first.
+///
+/// Returns a silenceable failure, with a "target op" note at `target`'s
+/// location, when `target` is not one of those three op kinds or when the
+/// selected decomposition fails. On success the operation returned by the
+/// decomposition is appended to `results`.
+DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  FailureOr<Operation *> maybeTransformed = failure();
+  // `supported` separates "not a Winograd transform op" from "decomposition
+  // failed" so the two cases can report different diagnostics below.
+  bool supported =
+      TypeSwitch<Operation *, bool>(target)
+          .Case([&](linalg::WinogradFilterTransformOp op) {
+            maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
+            return true;
+          })
+          .Case([&](linalg::WinogradInputTransformOp op) {
+            maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
+            return true;
+          })
+          .Case([&](linalg::WinogradOutputTransformOp op) {
+            maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
+            return true;
+          })
+          .Default([](Operation *) { return false; });
+
+  if (!supported) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "this operation is not supported to decompose into other operations";
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  // The op kind is known to be supported here; only the rewrite can fail.
+  if (failed(maybeTransformed)) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "decompose Winograd operations failed";
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index c6c770e2781ff0..b65b18699a15aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -490,8 +490,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
Type elementType = inputType.getElementType();
auto inputShape = inputType.getShape(); // N, H, W, C
int64_t inputN = inputShape[0];
- int64_t inputH = inputShape[1];
- int64_t inputW = inputShape[2];
int64_t inputC = inputShape[3];
auto valueType = cast<ShapedType>(retValue.getType());
auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
@@ -500,11 +498,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
int64_t alphaH = leftTransform ? m + r - 1 : 1;
int64_t alphaW = rightTransform ? m + r - 1 : 1;
- if ((inputH != (tileH * m) + (r - 1)) && inputH != 1)
- return Value();
- if ((inputW != (tileW * m) + (r - 1)) && inputW != 1)
- return Value();
-
auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange args) -> scf::ValueVector {
Value tileHIter = ivs[0];
@@ -1169,6 +1162,24 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
return winogradConv2DHelper(rewriter, op, m, r);
}
+/// Public entry point for decomposing a linalg.winograd_filter_transform op.
+/// All of the rewriting is done by decomposeWinogradFilterTransformHelper;
+/// its result (success with an operation, or failure) is returned unchanged.
+FailureOr<Operation *> decomposeWinogradFilterTransformOp(
+    RewriterBase &rewriter, linalg::WinogradFilterTransformOp op) {
+  FailureOr<Operation *> decomposed =
+      decomposeWinogradFilterTransformHelper(rewriter, op);
+  return decomposed;
+}
+
+/// Public entry point for decomposing a linalg.winograd_input_transform op.
+/// All of the rewriting is done by decomposeWinogradInputTransformHelper;
+/// its result (success with an operation, or failure) is returned unchanged.
+FailureOr<Operation *> decomposeWinogradInputTransformOp(
+    RewriterBase &rewriter, linalg::WinogradInputTransformOp op) {
+  FailureOr<Operation *> decomposed =
+      decomposeWinogradInputTransformHelper(rewriter, op);
+  return decomposed;
+}
+
+/// Public entry point for decomposing a linalg.winograd_output_transform op.
+/// All of the rewriting is done by decomposeWinogradOutputTransformHelper;
+/// its result (success with an operation, or failure) is returned unchanged.
+FailureOr<Operation *> decomposeWinogradOutputTransformOp(
+    RewriterBase &rewriter, linalg::WinogradOutputTransformOp op) {
+  FailureOr<Operation *> decomposed =
+      decomposeWinogradOutputTransformHelper(rewriter, op);
+  return decomposed;
+}
+
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
int64_t r) {
MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
new file mode 100644
index 00000000000000..6bb3fb1423edc6
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -0,0 +1,292 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+// Aligned case: a 10x10 input convolved with a 3x3 filter gives an 8x8
+// output, which splits evenly into 2x2 tiles of size m = 4. The body is the
+// Winograd form of conv2d: filter transform, input transform, batch_matmul
+// over the 36 (= 6x6, i.e. (m + r - 1)^2) transformed positions, then the
+// output transform.
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+ %0 = tensor.empty() : tensor<6x6x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %2 = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+ %4 = tensor.empty() : tensor<36x8x2xf32>
+ %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+ %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ return %6 : tensor<2x8x8x2xf32>
+}
+
+// Transform script: tile the input and output transforms with tile sizes
+// [0, 0, 1, 1, 0, 0] (the expected output shows one loop per tile row and one
+// per tile column), then decompose the untiled filter transform and the
+// tiled input/output transforms.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ // The filter transform is decomposed without tiling.
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ // Re-match the tiled input/output transforms through the tiled-op handles
+ // (%3, %5) and decompose them.
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// Expected IR for @conv2d after transforms and -canonicalize.
+// $MAP0 scales a tile index by m = 4 to get a spatial offset.
+// $MAP1/$MAP2 are the indexing maps of the scalar broadcast used in the
+// output transform.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// Filter transform: loops over filter dims 0 (size 2) and 3 (size 5).
+// Each iteration runs two matmuls on a 1x3x3x1 slice and inserts a 6x6 result.
+// CHECK: %[[S0:.*]] = tensor.empty()
+// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK: %[[S11:.*]] = linalg.matmul
+// CHECK: %[[S13:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// Tiled input transform: the outer loops walk the 2x2 tile grid.
+// Each tile extracts a 6x6 window at offset (4*i, 4*j).
+// The inner (decomposed) loops run over the N (2) and C (5) dims.
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK: %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK: %[[S15:.*]] = linalg.matmul
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S17]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_9]]
+// CHECK: scf.yield %[[S13]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// Batched matmul between the collapsed transformed input and filter.
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
+// Tiled output transform: each tile writes a 4x4 window at offset (4*i, 4*j).
+// The 1.024000e+03 constant is broadcast with a linalg.generic and applied
+// with linalg.mul.
+// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK: %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[S19:.*]] = linalg.matmul
+// CHECK: %[[S20:.*]] = tensor.empty()
+// CHECK: %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<4x4xf32>
+// CHECK: %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_9]]
+// CHECK: scf.yield %[[S15]]
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+
+// -----
+
+// Unaligned case: the 11x11 input and 9x9 output are not covered by whole
+// 4x4 tiles. Both are padded by 3 on the high side of H and W (input to
+// 14x14, output to 12x12), giving a 3x3 tile grid. The padded result is then
+// sliced back to 2x9x9x2.
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<6x6x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+ ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
+ tensor.yield %cst : f32
+ } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+ %2 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+ %4 = tensor.empty() : tensor<36x18x2xf32>
+ %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+ %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+ %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+ ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
+ tensor.yield %cst : f32
+ } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+ %extracted_slice = tensor.extract_slice %6[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+ return %extracted_slice : tensor<2x9x9x2xf32>
+}
+
+// Transform script: tile the input and output transforms with tile sizes
+// [0, 0, 1, 1, 0, 0] (one loop per tile row and one per tile column in the
+// expected output), then decompose the untiled filter transform and the
+// tiled input/output transforms.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ // The filter transform is decomposed without tiling.
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ // Re-match the tiled input/output transforms through the tiled-op handles
+ // (%3, %5) and decompose them.
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// Expected IR for @conv2d_unaligned.
+// Same structure as the aligned case, but both tile loops run to 3 and the
+// spatial slices come from the padded input/output tensors.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// Filter transform (untiled), decomposed into loops over filter dims 0 and 3.
+// CHECK: %[[S0:.*]] = tensor.empty()
+// CHECK: %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK: %[[S11:.*]] = linalg.matmul
+// CHECK: %[[S13:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
+// CHECK: scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// Tiled input transform over the padded 14x14 input: a 3x3 tile grid, each
+// tile reading a 6x6 window at offset (4*i, 4*j).
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK: %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK: %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK: %[[S15:.*]] = linalg.matmul
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S17]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
+// CHECK: scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// Batched matmul between the collapsed transformed input and filter.
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
+// Tiled output transform into the padded 12x12 output, followed by the
+// extract_slice back to 9x9.
+// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK: %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[S19:.*]] = linalg.matmul
+// CHECK: %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK: %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<4x4xf32>
+// CHECK: %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S22]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_12]]
+// CHECK: scf.yield %[[S15]] : tensor<2x4x4x2xf32>
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
+// CHECK: return %[[EXTRACTED_SLICE]]
+
+// -----
+
+// Width-1 case (F(4x1, 3x1)): the filter, input and output all have W = 1,
+// so only the H dimension is Winograd-transformed. The transformed shapes
+// use alpha = 6 along H and 1 along W, and there is a single 1x1 tile grid.
+func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+ %0 = tensor.empty() : tensor<6x1x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+ %4 = tensor.empty() : tensor<6x2x2xf32>
+ %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+ %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+ return %6 : tensor<2x4x1x2xf32>
+}
+
+// Transform script: tile the input and output transforms with tile sizes
+// [0, 0, 1, 1, 0, 0], then decompose the filter transform and the tiled
+// input/output transforms. With a single tile, the expected output shows
+// the tile loops canonicalized away.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ // The filter transform is decomposed without tiling.
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ // Re-match the tiled input/output transforms through the tiled-op handles
+ // (%3, %5) and decompose them.
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// Expected IR for @conv2d_mx1_rx1.
+// The single tile along each spatial dim leaves no tile loops and no affine
+// offset map.
+// $MAP0/$MAP1 are the indexing maps of the scalar broadcast in the output
+// transform. They are matched through captures rather than the printer's
+// alias names.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// Filter transform: a single matmul per (dim 0, dim 3) pair, since only H
+// is transformed.
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// Input transform.
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// Batched matmul.
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S5:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
+// Output transform, including the scaling by the 3.200000e+01 constant.
+// CHECK: %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK: %[[S11:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S10]] : tensor<4x1xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<4x1xf32>
+// CHECK: %[[S12:.*]] = linalg.mul ins(%[[S11]], %[[S9]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// CHECK: return %[[S6]]
diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
new file mode 100644
index 00000000000000..21522a2083b463
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir
@@ -0,0 +1,380 @@
+// RUN: mlir-opt %s -transform-interpreter --split-input-file | FileCheck %s
+
+// Tile linalg.winograd_filter_transform with tile_sizes [1, 0, 0, 1]; only the
+// first and last loops are tiled, so the script binds two loop handles.
+func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+ %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ return %0 : tensor<6x6x5x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @tile_winograd_filter(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK: %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32>
+// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32>
+
+// -----
+
+// Same filter transform, tiled with tile_sizes [1, 0, 0, 2]. The last input
+// dimension has size 5, which is not a multiple of 2, so the last tile along
+// it is partial and its size is computed with affine.min.
+func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+ %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ return %0 : tensor<6x6x5x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 0, 0, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 5, 2)>
+// CHECK-LABEL: func.func @tile_winograd_filter(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C2_1]]
+// CHECK: %[[C5_2:.*]] = arith.constant 5 : index
+// CHECK: %[[S3:.*]] = affine.min #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, %[[S3]]] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x?xf32>
+// CHECK: %[[EXTRACTED_SLICE_3:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32>
+// CHECK: %[[S4:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32>
+
+// -----
+
+// Filter transform whose operands have a size-1 dimension
+// (2x3x1x5 -> 6x1x5x2), tiled with tile_sizes [1, 0, 0, 1].
+func.func @tile_winograd_filter(%arg0: tensor<2x3x1x5xf32>, %arg1: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
+ %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x1x5xf32>) outs(%arg1 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ return %0 : tensor<6x1x5x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @tile_winograd_filter(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x1x5xf32>, %[[ARG1:.*]]: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]]
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
+// CHECK: %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32>
+// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32>
+
+// -----
+
+// Tile linalg.winograd_input_transform with tile_sizes [0, 0, 1, 1, 0, 0];
+// only the third and fourth loops are tiled (two loop handles).
+func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+ %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ return %0 : tensor<6x6x2x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop3:2 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<() -> (6)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK: %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK: %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S5:.*]] = affine.apply #[[$MAP1]]()
+// CHECK: %[[S6:.*]] = affine.apply #[[$MAP1]]()
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x5xf32>
+// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[S7:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+
+// -----
+
+// Input transform tiled with tile_sizes [0, 0, 1, 1, 1, 1]; the last four
+// loops are tiled by 1 (four loop handles).
+func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+ %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ return %0 : tensor<6x6x2x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<() -> (6)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_4:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
+// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
+// CHECK: %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK: %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S7:.*]] = affine.apply #[[$MAP1]]()
+// CHECK: %[[S8:.*]] = affine.apply #[[$MAP1]]()
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], %[[S8]], 1] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<1x?x?x1xf32>
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+
+// -----
+
+// Input transform tiled with tile_sizes [0, 0, 2, 2, 2, 2]. The last dimension
+// of both operands has size 5, which is not a multiple of 2, so the last tile
+// along it is partial and its size is computed with affine.min.
+func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+ %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ return %0 : tensor<6x6x2x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 5, 2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<() -> (10)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_7:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_5:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_6:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_8:.*]] = arith.constant 2 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
+// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C2_5]] step %[[C2_6]]
+// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_7]] to %[[C5]] step %[[C2_8]]
+// CHECK: %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG8]])
+// CHECK: %[[S6:.*]] = affine.apply #[[$MAP1]](%[[ARG2]])
+// CHECK: %[[S7:.*]] = affine.apply #[[$MAP1]](%[[ARG4]])
+// CHECK: %[[S8:.*]] = affine.apply #[[$MAP2]]()
+// CHECK: %[[S9:.*]] = affine.apply #[[$MAP2]]()
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S6]], %[[S7]], %[[ARG8]]] [2, %[[S8]], %[[S9]], %[[S5]]] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x?xf32>
+// CHECK: %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32>
+// CHECK: %[[S10:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32>
+
+// -----
+
+// Input transform whose operands have size-1 dimensions
+// (2x1x10x5 -> 1x6x1x2x2x5), tiled with tile_sizes [0, 0, 1, 1, 1, 1].
+func.func @tile_winograd_input(%arg0: tensor<2x1x10x5xf32>, %arg1: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
+ %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x1x10x5xf32>) outs(%arg1 : tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32>
+ return %0 : tensor<1x6x1x2x2x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<() -> (6)>
+// CHECK-LABEL: func.func @tile_winograd_input(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1x10x5xf32>, %[[ARG1:.*]]: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_4:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_7:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C1]] step %[[C1_0]]
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_2]]
+// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]]
+// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]]
+// CHECK: %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK: %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S7:.*]] = affine.apply #[[$MAP1]]()
+// CHECK: %[[S8:.*]] = affine.apply #[[$MAP1]]()
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], 0, %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32>
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32>
+// CHECK: %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32>
+
+// -----
+
+// Tile linalg.winograd_output_transform with tile_sizes [0, 0, 1, 1, 0, 0];
+// only the third and fourth loops are tiled (two loop handles). The final
+// directive verifies that the tiled op consumes the two extracted slices.
+func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x2x2xf32>, %arg1: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+ %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x2x2xf32>) outs(%arg1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ return %0 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<() -> (4)>
+// CHECK-LABEL: func.func @tile_winograd_output(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<6x6x2x2x2x2xf32>, %[[ARG1:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK: %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK: %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S5:.*]] = affine.apply #[[$MAP1]]()
+// CHECK: %[[S6:.*]] = affine.apply #[[$MAP1]]()
+// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32>
+// CHECK: %[[S7:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x6x1x1x2x2xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
+
+// -----
+
+// Output transform tiled with tile_sizes [0, 0, 2, 2, 2, 2]. The last two
+// dimensions of the input operand (sizes 3 and 5) are not multiples of 2, so
+// their last tiles are partial and sized with affine.min. The final directive
+// verifies that the tiled op consumes the two extracted slices.
+func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x3x5xf32>, %arg1: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
+ %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x3x5xf32>) outs(%arg1 : tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32>
+ return %0 : tensor<3x8x8x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 3, 2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (-d0 + 5, 2)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP3:.+]] = affine_map<() -> (8)>
+// CHECK-LABEL: func.func @tile_winograd_output(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<6x6x2x2x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_6:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C2_0:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_5:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C2_7:.*]] = arith.constant 2 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
+// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]]
+// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]]
+// CHECK: %[[C3_8:.*]] = arith.constant 3 : index
+// CHECK: %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[C5_9:.*]] = arith.constant 5 : index
+// CHECK: %[[S6:.*]] = affine.min #[[$MAP1]](%[[ARG8]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, %[[S5]], %[[S6]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x3x5xf32> to tensor<6x6x2x2x?x?xf32>
+// CHECK: %[[S7:.*]] = affine.apply #[[$MAP2]](%[[ARG2]])
+// CHECK: %[[S8:.*]] = affine.apply #[[$MAP2]](%[[ARG4]])
+// CHECK: %[[S9:.*]] = affine.apply #[[$MAP3]]()
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP3]]()
+// CHECK: %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor<?x?x?x?xf32>
+// CHECK: %[[S11:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x6x2x2x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+// -----
+
+// Output transform whose operands have size-1 dimensions
+// (6x1x2x1x3x5 -> 3x8x1x5), tiled with tile_sizes [0, 0, 1, 1, 1, 1].
+func.func @tile_winograd_output(%arg0 : tensor<6x1x2x1x3x5xf32>, %arg1: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
+ %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x1x2x1x3x5xf32>) outs(%arg1 : tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32>
+ return %0 : tensor<3x8x1x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<() -> (4)>
+// CHECK-LABEL: func.func @tile_winograd_output(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<6x1x2x1x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_5:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index
+// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
+// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]]
+// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]]
+// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]]
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32>
+// CHECK: %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
+// CHECK: %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S7:.*]] = affine.apply #[[$MAP1]]()
+// CHECK: %[[S8:.*]] = affine.apply #[[$MAP1]]()
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
+// CHECK: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>
More information about the Mlir-commits
mailing list