[Mlir-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
Hsiangkai Wang
llvmlistbot at llvm.org
Wed Jul 17 23:24:58 PDT 2024
https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/96184
>From 72477955477477f1d95e2125cd33122899154f0d Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:44:27 +0100
Subject: [PATCH 1/4] [mlir][linalg] Implement TilingInterface for winograd
operations
To support conv2d input data of arbitrary size, implement
TilingInterface for winograd operations. Before converting winograd
operations into nested loops with matrix multiplies, tile the input of
conv2d into the supported sizes first.
Add a transform operation structured.decompose_winograd_op to decompose
winograd operations. Before applying the transform op, use tile_using_for
to tile the input data into supported size. The test case shows how to
tile and decompose winograd operations.
---
.../mlir/Dialect/Linalg/IR/LinalgOps.td | 22 +-
.../Linalg/TransformOps/LinalgTransformOps.td | 37 +++
.../Dialect/Linalg/Transforms/Transforms.h | 45 +++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 216 +++++++++++++++
.../TransformOps/LinalgTransformOps.cpp | 41 +++
.../Linalg/Transforms/WinogradConv2D.cpp | 18 ++
.../transform-tile-and-winograd-rewrite.mlir | 260 ++++++++++++++++++
7 files changed, 633 insertions(+), 6 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a9007c8db3078..bc2df7082ef8e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,8 +154,8 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}
-def Linalg_WinogradFilterTransformOp :
- Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
+ [AllElementTypesMatch<["filter", "output"]>]> {
let summary = "Winograd filter transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -193,8 +193,13 @@ def Linalg_WinogradFilterTransformOp :
let hasVerifier = 1;
}
-def Linalg_WinogradInputTransformOp :
- Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
+ [AllElementTypesMatch<["input", "output"]>,
+ DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
let summary = "Winograd input transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -232,8 +237,13 @@ def Linalg_WinogradInputTransformOp :
let hasVerifier = 1;
}
-def Linalg_WinogradOutputTransformOp :
- Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
+ [AllElementTypesMatch<["value", "output"]>,
+ DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
let summary = "Winograd output transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ecc86999006db..106f0d79d9792 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2697,4 +2697,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
}];
}
+def DecomposeWinogradOp : Op<Transform_Dialect,
+ "structured.decompose_winograd_op",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Decompose winograd operations. It will convert filter, input and output
+ transform operations into a combination of scf, tensor, and linalg
+ equivalent operations. Before applying this transform operation, users
+ need to tile winograd transform operations into supported sizes.
+
+ #### Return modes:
+
+ This operation fails if `target` is unsupported. Otherwise, the operation
+ succeeds and returns a handle of the sequence that replaces the original
+ operations.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type($target, results)";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>
+ ];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0c7a8edff222f..8ec698b8c0f59 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1339,6 +1339,51 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcOp op, int64_t m,
int64_t r);
+/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
+/// FHWC. The transformation matrix is 2-dimensional. We need to extract H x W
+/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
+/// C. After the rewriting, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+/// scf.for %c = lo_c to hi_c step 1
+/// %extracted = extract filter<h x w> from filter<f x h x w x c>
+/// %ret = linalg.matmul G, %extracted
+/// %ret = linalg.matmul %ret, GT
+/// %inserted = insert %ret into filter<h x w x c x f>
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+ linalg::WinogradFilterTransformOp op);
+
+/// Rewrite linalg.winograd_input_transform. The data layout of the input is
+/// NHWC. The transformation matrix is 2-dimensional. We need to extract H x W
+/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
+/// C. After the rewriting, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+/// scf.for %c = lo_c to hi_c step 1
+/// %extracted = extract input<h x w> from input<n x h x w x c>
+/// %ret = linalg.matmul BT, %extracted
+/// %ret = linalg.matmul %ret, B
+/// %inserted = insert %ret into input<h x w x n x c>
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+ linalg::WinogradInputTransformOp op);
+
+/// Rewrite linalg.winograd_output_transform. The data layout of the output is
+/// HWNF. The transformation matrix is 2-dimensional. We need to extract H x W
+/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
+/// F. After the transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+/// scf.for %f = lo_f to hi_f step 1
+/// %extracted = extract input<h x w> from result<h x w x n x f>
+/// %ret = linalg.matmul AT, %extracted
+/// %ret = linalg.matmul %ret, A
+/// %inserted = insert %ret into ret<n x h x w x f>
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+ linalg::WinogradOutputTransformOp op);
+
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cefaad9b22653..c93aebb095784 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2776,6 +2776,15 @@ LogicalResult WinogradFilterTransformOp::verify() {
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//
+Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
+ Location loc) {
+ if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
+ auto intAttr = cast<IntegerAttr>(attr);
+ return builder.create<arith::ConstantOp>(loc, intAttr);
+ }
+ return opFoldResult.get<Value>();
+}
+
LogicalResult WinogradInputTransformOp::verify() {
auto inputType = cast<ShapedType>(getInput().getType());
ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2813,6 +2822,113 @@ LogicalResult WinogradInputTransformOp::verify() {
return success();
}
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+ Location loc = getLoc();
+ auto indexType = builder.getIndexType();
+ auto zeroAttr = builder.getIntegerAttr(indexType, 0);
+ auto oneAttr = builder.getIntegerAttr(indexType, 1);
+ Value output = getOutput();
+ SmallVector<Range> loopBounds(6);
+ for (unsigned dim = 0; dim < 6; ++dim) {
+ loopBounds[dim].offset = zeroAttr;
+ loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+ loopBounds[dim].stride = oneAttr;
+ }
+ return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+ SmallVector<utils::IteratorType> iteratorTypes(6,
+ utils::IteratorType::parallel);
+ return iteratorTypes;
+}
+
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) {
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ auto oneAttr = builder.getI64IntegerAttr(1);
+
+ resultOffsets.push_back(zeroAttr);
+ resultOffsets.push_back(zeroAttr);
+ resultOffsets.push_back(offsets[2]);
+ resultOffsets.push_back(offsets[3]);
+ resultOffsets.push_back(zeroAttr);
+ resultOffsets.push_back(zeroAttr);
+ resultSizes.push_back(sizes[0]);
+ resultSizes.push_back(sizes[1]);
+ resultSizes.push_back(oneAttr);
+ resultSizes.push_back(oneAttr);
+ resultSizes.push_back(sizes[4]);
+ resultSizes.push_back(sizes[5]);
+
+ return success();
+}
+
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+ auto oneAttr = builder.getI64IntegerAttr(1);
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ Value input = getInput();
+ auto inputType = cast<ShapedType>(input.getType());
+ auto inputShape = inputType.getShape();
+ int64_t inputH = inputShape[1];
+ int64_t inputW = inputShape[2];
+ int64_t m = getM();
+ int64_t r = getR();
+ int64_t alpha = m + r - 1;
+ int64_t alphaH = inputH != 1 ? alpha : 1;
+ int64_t alphaW = inputW != 1 ? alpha : 1;
+ auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+ auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+ Location loc = getLoc();
+ SmallVector<Value> tiledOperands;
+ SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+ auto context = builder.getContext();
+ auto affineMap =
+ AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+ Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+ loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+ Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+ loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
+
+ sliceOffsets.push_back(zeroAttr);
+ sliceOffsets.push_back(mappedOffset1);
+ sliceOffsets.push_back(mappedOffset2);
+ sliceOffsets.push_back(zeroAttr);
+ sliceSizes.push_back(sizes[4]);
+ sliceSizes.push_back(alphaHAttr);
+ sliceSizes.push_back(alphaWAttr);
+ sliceSizes.push_back(sizes[5]);
+ SmallVector<OpFoldResult> inputStrides(4, oneAttr);
+ tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+ loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+
+ sliceOffsets.clear();
+ sliceSizes.clear();
+ if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
+ sliceSizes)))
+ return failure();
+
+ SmallVector<OpFoldResult> outputStrides(6, oneAttr);
+ tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
+
+ SmallVector<Type, 4> resultTypes;
+ resultTypes.push_back(tiledOperands[1].getType());
+ Operation *tiledOp =
+ mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//
@@ -2855,6 +2971,106 @@ LogicalResult WinogradOutputTransformOp::verify() {
return success();
}
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+ Location loc = getLoc();
+ auto indexType = builder.getIndexType();
+ auto zeroAttr = builder.getIntegerAttr(indexType, 0);
+ auto oneAttr = builder.getIntegerAttr(indexType, 1);
+ Value value = getValue();
+ SmallVector<Range> loopBounds(6);
+ for (unsigned dim = 0; dim < 6; ++dim) {
+ loopBounds[dim].offset = zeroAttr;
+ loopBounds[dim].size = getDimValue(builder, loc, value, dim);
+ loopBounds[dim].stride = oneAttr;
+ }
+ return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+ SmallVector<utils::IteratorType> iteratorTypes(6,
+ utils::IteratorType::parallel);
+ return iteratorTypes;
+}
+
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) {
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ Value output = getOutput();
+ auto outputType = cast<ShapedType>(output.getType());
+ auto outputShape = outputType.getShape();
+ int64_t outputH = outputShape[1];
+ int64_t outputW = outputShape[2];
+ int64_t m = getM();
+ auto heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
+ auto widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
+
+ Location loc = getLoc();
+ auto context = builder.getContext();
+ auto affineMap =
+ AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+ Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+ loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+ Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+ loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
+
+ resultOffsets.push_back(zeroAttr);
+ resultOffsets.push_back(mappedOffset1);
+ resultOffsets.push_back(mappedOffset2);
+ resultOffsets.push_back(zeroAttr);
+ resultSizes.push_back(sizes[4]);
+ resultSizes.push_back(heightM);
+ resultSizes.push_back(widthM);
+ resultSizes.push_back(sizes[5]);
+ return success();
+}
+
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
+ OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+ auto oneAttr = builder.getI64IntegerAttr(1);
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ Location loc = getLoc();
+ SmallVector<Value> tiledOperands;
+ SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+ sliceOffsets.push_back(zeroAttr);
+ sliceOffsets.push_back(zeroAttr);
+ sliceOffsets.push_back(offsets[2]);
+ sliceOffsets.push_back(offsets[3]);
+ sliceOffsets.push_back(zeroAttr);
+ sliceOffsets.push_back(zeroAttr);
+ sliceSizes.push_back(sizes[0]);
+ sliceSizes.push_back(sizes[1]);
+ sliceSizes.push_back(oneAttr);
+ sliceSizes.push_back(oneAttr);
+ sliceSizes.push_back(sizes[4]);
+ sliceSizes.push_back(sizes[5]);
+ SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
+ tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+ loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
+
+ sliceOffsets.clear();
+ sliceSizes.clear();
+ if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
+ sliceSizes)))
+ return failure();
+
+ SmallVector<OpFoldResult> strides(4, oneAttr);
+ tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+ SmallVector<Type, 4> resultTypes;
+ resultTypes.push_back(tiledOperands[1].getType());
+ Operation *tiledOp =
+ mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b611347b8de2e..e279a6089302e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3748,6 +3748,47 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ rewriter.setInsertionPoint(target);
+ FailureOr<Operation *> maybeTransformed = failure();
+ bool supported =
+ TypeSwitch<Operation *, bool>(target)
+ .Case([&](linalg::WinogradFilterTransformOp op) {
+ maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
+ return true;
+ })
+ .Case([&](linalg::WinogradInputTransformOp op) {
+ maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
+ return true;
+ })
+ .Case([&](linalg::WinogradOutputTransformOp op) {
+ maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
+ return true;
+ })
+ .Default([&](Operation *op) { return false; });
+
+ if (!supported) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "this operation is not supported to decompose into other operations";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ if (supported && failed(maybeTransformed)) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "decompose Winograd operations failed";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ results.push_back(*maybeTransformed);
+ return DiagnosedSilenceableFailure::success();
+}
+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 754f832e98eea..c9a6c906008a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -1170,6 +1170,24 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
return winogradConv2DHelper(rewriter, op, m, r);
}
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+ linalg::WinogradFilterTransformOp op) {
+ return decomposeWinogradFilterTransformHelper(rewriter, op);
+}
+
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+ linalg::WinogradInputTransformOp op) {
+ return decomposeWinogradInputTransformHelper(rewriter, op);
+}
+
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+ linalg::WinogradOutputTransformOp op) {
+ return decomposeWinogradOutputTransformHelper(rewriter, op);
+}
+
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
int64_t r) {
MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
new file mode 100644
index 0000000000000..29833324bb21b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -0,0 +1,260 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+ %0 = tensor.empty() : tensor<6x6x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %2 = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+ %4 = tensor.empty() : tensor<36x8x2xf32>
+ %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+ %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ return %6 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK: scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT: scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG3]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: %[[S5:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK: %[[S6:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK: scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT: scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: tensor.extract_slice %[[ARG0]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[S15:.*]] = scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK-NEXT: scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG8]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
+// CHECK: %[[S9:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK: %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK: scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT: scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK: tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
+// CHECK: %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: %[[S16:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+
+// -----
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<6x6x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+ ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
+ tensor.yield %cst : f32
+ } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+ %2 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+ %4 = tensor.empty() : tensor<36x18x2xf32>
+ %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+ %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+ %padded_1 = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+ ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
+ tensor.yield %cst : f32
+ } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+ %extracted_slice = tensor.extract_slice %6[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+ return %extracted_slice : tensor<2x9x9x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK: linalg.matmul
+// CHECK: %[[S13:.*]] = linalg.matmul
+// CHECK: %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x5x2xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK: ^bb0({{.*}}):
+// CHECK: tensor.yield %[[CST_6]] : f32
+// CHECK: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32>
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK: scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK: linalg.matmul
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[S18:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S17]] into %[[S18]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x1x1x2x5xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_14]] : tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x3x3x2x5xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
+// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK: ^bb0({{.*}}):
+// CHECK: tensor.yield %[[CST_6]] : f32
+// CHECK: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK: tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32>
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK: scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: %[[S22:.*]] = linalg.mul
+// CHECK: %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK: %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK: %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_14]] : tensor<2x4x4x2xf32>
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK: }
+
+// -----
+
+func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+ %0 = tensor.empty() : tensor<6x1x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+ %4 = tensor.empty() : tensor<6x2x2xf32>
+ %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+ %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+ return %6 : tensor<2x4x1x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
+// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<1x3x1x1xf32> to tensor<3x1xf32>
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x5x2xf32>
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<2x6x1x5xf32> to tensor<1x6x1x1xf32>
+// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> to tensor<6x1xf32>
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> into tensor<6x1x1x1x2x5xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x1x1x2x5xf32>
+// CHECK: %[[S5:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x2x2xf32> to tensor<6x1x1x1x1x1xf32>
+// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> to tensor<6x1xf32>
+// CHECK: linalg.matmul
+// CHECK: %[[S12:.*]] = linalg.mul
+// CHECK: %[[S13:.*]] = tensor.empty() : tensor<1x4x1x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[S13]][0, 0, 0, 0] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<4x1xf32> into tensor<1x4x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<1x4x1x1xf32> into tensor<2x4x1x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<2x4x1x2xf32>
From 936356fd49b7f83aa307eb3fdf46348ee151c149 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 17 Jul 2024 18:26:14 +0100
Subject: [PATCH 2/4] Address ftynse's comments
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 122 +++----
.../transform-tile-and-winograd-rewrite.mlir | 332 ++++++++++--------
2 files changed, 226 insertions(+), 228 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c93aebb095784..c02abf692784c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2776,15 +2776,6 @@ LogicalResult WinogradFilterTransformOp::verify() {
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//
-Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
- Location loc) {
- if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
- auto intAttr = cast<IntegerAttr>(attr);
- return builder.create<arith::ConstantOp>(loc, intAttr);
- }
- return opFoldResult.get<Value>();
-}
-
LogicalResult WinogradInputTransformOp::verify() {
auto inputType = cast<ShapedType>(getInput().getType());
ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2825,9 +2816,9 @@ LogicalResult WinogradInputTransformOp::verify() {
SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
Location loc = getLoc();
- auto indexType = builder.getIndexType();
- auto zeroAttr = builder.getIntegerAttr(indexType, 0);
- auto oneAttr = builder.getIntegerAttr(indexType, 1);
+ IndexType indexType = builder.getIndexType();
+ IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
+ IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
Value output = getOutput();
SmallVector<Range> loopBounds(6);
for (unsigned dim = 0; dim < 6; ++dim) {
@@ -2849,21 +2840,13 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
- auto zeroAttr = builder.getI64IntegerAttr(0);
- auto oneAttr = builder.getI64IntegerAttr(1);
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+ IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
- resultOffsets.push_back(zeroAttr);
- resultOffsets.push_back(zeroAttr);
- resultOffsets.push_back(offsets[2]);
- resultOffsets.push_back(offsets[3]);
- resultOffsets.push_back(zeroAttr);
- resultOffsets.push_back(zeroAttr);
- resultSizes.push_back(sizes[0]);
- resultSizes.push_back(sizes[1]);
- resultSizes.push_back(oneAttr);
- resultSizes.push_back(oneAttr);
- resultSizes.push_back(sizes[4]);
- resultSizes.push_back(sizes[5]);
+ resultOffsets.append(
+ {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+ resultSizes.append(
+ {sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
return success();
}
@@ -2872,11 +2855,11 @@ FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
- auto oneAttr = builder.getI64IntegerAttr(1);
- auto zeroAttr = builder.getI64IntegerAttr(0);
+ IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
Value input = getInput();
auto inputType = cast<ShapedType>(input.getType());
- auto inputShape = inputType.getShape();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputH = inputShape[1];
int64_t inputW = inputShape[2];
int64_t m = getM();
@@ -2884,29 +2867,25 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
int64_t alpha = m + r - 1;
int64_t alphaH = inputH != 1 ? alpha : 1;
int64_t alphaW = inputW != 1 ? alpha : 1;
- auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
- auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+ IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+ IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
Location loc = getLoc();
SmallVector<Value> tiledOperands;
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
- auto context = builder.getContext();
+ MLIRContext *context = builder.getContext();
auto affineMap =
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
- loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+ loc, affineMap,
+ getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
- loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
-
- sliceOffsets.push_back(zeroAttr);
- sliceOffsets.push_back(mappedOffset1);
- sliceOffsets.push_back(mappedOffset2);
- sliceOffsets.push_back(zeroAttr);
- sliceSizes.push_back(sizes[4]);
- sliceSizes.push_back(alphaHAttr);
- sliceSizes.push_back(alphaWAttr);
- sliceSizes.push_back(sizes[5]);
+ loc, affineMap,
+ getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+
+ sliceOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
+ sliceSizes.append({sizes[4], alphaHAttr, alphaWAttr, sizes[5]});
SmallVector<OpFoldResult> inputStrides(4, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
@@ -2921,7 +2900,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
- SmallVector<Type, 4> resultTypes;
+ SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
@@ -2974,9 +2953,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
Location loc = getLoc();
- auto indexType = builder.getIndexType();
- auto zeroAttr = builder.getIntegerAttr(indexType, 0);
- auto oneAttr = builder.getIntegerAttr(indexType, 1);
+ IndexType indexType = builder.getIndexType();
+ IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
+ IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
Value value = getValue();
SmallVector<Range> loopBounds(6);
for (unsigned dim = 0; dim < 6; ++dim) {
@@ -2998,57 +2977,44 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
- auto zeroAttr = builder.getI64IntegerAttr(0);
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
Value output = getOutput();
auto outputType = cast<ShapedType>(output.getType());
- auto outputShape = outputType.getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
int64_t outputH = outputShape[1];
int64_t outputW = outputShape[2];
int64_t m = getM();
- auto heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
- auto widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
+ IntegerAttr heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
+ IntegerAttr widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
Location loc = getLoc();
- auto context = builder.getContext();
+ MLIRContext *context = builder.getContext();
auto affineMap =
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
- loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+ loc, affineMap,
+ getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
- loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
-
- resultOffsets.push_back(zeroAttr);
- resultOffsets.push_back(mappedOffset1);
- resultOffsets.push_back(mappedOffset2);
- resultOffsets.push_back(zeroAttr);
- resultSizes.push_back(sizes[4]);
- resultSizes.push_back(heightM);
- resultSizes.push_back(widthM);
- resultSizes.push_back(sizes[5]);
+ loc, affineMap,
+ getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+
+ resultOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
+ resultSizes.append({sizes[4], heightM, widthM, sizes[5]});
return success();
}
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
- auto oneAttr = builder.getI64IntegerAttr(1);
- auto zeroAttr = builder.getI64IntegerAttr(0);
+ IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
Location loc = getLoc();
SmallVector<Value> tiledOperands;
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
- sliceOffsets.push_back(zeroAttr);
- sliceOffsets.push_back(zeroAttr);
- sliceOffsets.push_back(offsets[2]);
- sliceOffsets.push_back(offsets[3]);
- sliceOffsets.push_back(zeroAttr);
- sliceOffsets.push_back(zeroAttr);
- sliceSizes.push_back(sizes[0]);
- sliceSizes.push_back(sizes[1]);
- sliceSizes.push_back(oneAttr);
- sliceSizes.push_back(oneAttr);
- sliceSizes.push_back(sizes[4]);
- sliceSizes.push_back(sizes[5]);
+ sliceOffsets.append(
+ {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+ sliceSizes.append({sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
@@ -3063,7 +3029,7 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getOutput(), sliceOffsets, sliceSizes, strides));
- SmallVector<Type, 4> resultTypes;
+ SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 29833324bb21b..6bb3fb1423edc 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -31,52 +31,77 @@ module attributes {transform.with_named_sequence} {
}
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @conv2d
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK: scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK-NEXT: scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG3]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK: linalg.matmul
-// CHECK: linalg.matmul
-// CHECK: %[[S5:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK: %[[S6:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK: scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<6x6x2x2x2x5xf32>) {
-// CHECK-NEXT: scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x2x2x5xf32>) {
-// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: tensor.extract_slice %[[ARG0]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
-// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK: %[[S15:.*]] = scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK-NEXt: scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG8]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK: linalg.matmul
-// CHECK: linalg.matmul
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
-// CHECK: %[[S9:.*]] = linalg.batch_matmul
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK: %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK: scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
-// CHECK-NEXT: scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
-// CHECK: tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
-// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
-// CHECK: %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK-NEXT: scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK: linalg.matmul
-// CHECK: linalg.matmul
-// CHECK: %[[S16:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK: %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+// CHECK: %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S0:.*]] = tensor.empty()
+// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK: %[[S11:.*]] = linalg.matmul
+// CHECK: %[[S13:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK: %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK: %[[S15:.*]] = linalg.matmul
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S17]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_9]]
+// CHECK: scf.yield %[[S13]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
+// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK: %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[S19:.*]] = linalg.matmul
+// CHECK: %[[S20:.*]] = tensor.empty()
+// CHECK: %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<4x4xf32>
+// CHECK: %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_9]]
+// CHECK: scf.yield %[[S15]]
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
// -----
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<6x6x5x2xf32>
%1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
@@ -91,7 +116,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
%4 = tensor.empty() : tensor<36x18x2xf32>
%5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
%expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
- %padded_1 = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+ %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
tensor.yield %cst : f32
} : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
@@ -117,80 +142,82 @@ module attributes {transform.with_named_sequence} {
}
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
-// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
-// CHECK: linalg.matmul
-// CHECK: %[[S13:.*]] = linalg.matmul
-// CHECK: %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x5x2xf32>
-// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
-// CHECK: ^bb0({{.*}}):
-// CHECK: tensor.yield %[[CST_6]] : f32
-// CHECK: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
-// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK: tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32>
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK: scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK: linalg.matmul
-// CHECK: %[[S17:.*]] = linalg.matmul
-// CHECK: %[[S18:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S17]] into %[[S18]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x1x1x2x5xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_14]] : tensor<6x6x1x1x2x5xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x3x3x2x5xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
-// CHECK: %[[S6:.*]] = linalg.batch_matmul
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
-// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
-// CHECK: ^bb0({{.*}}):
-// CHECK: tensor.yield %[[CST_6]] : f32
-// CHECK: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK: tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6x1x1x2x2xf32>
-// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32>
-// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK: scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK: linalg.matmul
-// CHECK: linalg.matmul
-// CHECK: %[[S22:.*]] = linalg.mul
-// CHECK: %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
-// CHECK: %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
-// CHECK: %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_14]] : tensor<2x4x4x2xf32>
-// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK: }
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S0:.*]] = tensor.empty()
+// CHECK: %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK: %[[S11:.*]] = linalg.matmul
+// CHECK: %[[S13:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
+// CHECK: scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK: %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK: %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK: %[[S15:.*]] = linalg.matmul
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S17]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
+// CHECK: scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
+// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK: %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[S19:.*]] = linalg.matmul
+// CHECK: %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK: %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<4x4xf32>
+// CHECK: %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S22]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_12]]
+// CHECK: scf.yield %[[S15]] : tensor<2x4x4x2xf32>
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
+// CHECK: return %[[EXTRACTED_SLICE]]
// -----
-func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
%0 = tensor.empty() : tensor<6x1x5x2xf32>
%1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
%2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
@@ -200,7 +227,7 @@ func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
%4 = tensor.empty() : tensor<6x2x2xf32>
%5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
%expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
- %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
return %6 : tensor<2x4x1x2xf32>
}
@@ -220,41 +247,46 @@ module attributes {transform.with_named_sequence} {
}
}
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @conv2d_mx1_rx1
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C5:.*]] = arith.constant 5 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x1x5x2xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x5x2xf32>) {
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
-// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<1x3x1x1xf32> to tensor<3x1xf32>
-// CHECK: %[[S9:.*]] = linalg.matmul
-// CHECK: %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x5x2xf32>
-// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]]) -> (tensor<6x1x1x1x2x5xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x1x1x2x5xf32>) {
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<2x6x1x5xf32> to tensor<1x6x1x1xf32>
-// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> to tensor<6x1xf32>
-// CHECK: %[[S9:.*]] = linalg.matmul
-// CHECK: %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> into tensor<6x1x1x1x2x5xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x1x1x2x5xf32>
-// CHECK: %[[S5:.*]] = linalg.batch_matmul
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<2x4x1x2xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x4x1x2xf32>) {
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x2x2xf32> to tensor<6x1x1x1x1x1xf32>
-// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> to tensor<6x1xf32>
-// CHECK: linalg.matmul
-// CHECK: %[[S12:.*]] = linalg.mul
-// CHECK: %[[S13:.*]] = tensor.empty() : tensor<1x4x1x1xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[S13]][0, 0, 0, 0] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<4x1xf32> into tensor<1x4x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<1x4x1x1xf32> into tensor<2x4x1x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<2x4x1x2xf32>
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S5:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
+// CHECK: %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK: %[[S11:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S10]] : tensor<4x1xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<4x1xf32>
+// CHECK: %[[S12:.*]] = linalg.mul ins(%[[S11]], %[[S9]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// CHECK: return %[[S6]]
>From 73bbca1f9c780514f33774b77356d2e4a27d8292 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 18 Jul 2024 06:52:17 +0100
Subject: [PATCH 3/4] Address more comments
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 18 ++++++++----------
1 file changed, 8 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c02abf692784c..8901715ed144b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2890,15 +2890,14 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
- sliceOffsets.clear();
- sliceSizes.clear();
- if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
- sliceSizes)))
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
+ if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
+ resultSizes)))
return failure();
SmallVector<OpFoldResult> outputStrides(6, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
+ loc, getOutput(), resultOffsets, resultSizes, outputStrides));
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
@@ -3019,15 +3018,14 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
- sliceOffsets.clear();
- sliceSizes.clear();
- if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
- sliceSizes)))
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
+ if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
+ resultSizes)))
return failure();
SmallVector<OpFoldResult> strides(4, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
- loc, getOutput(), sliceOffsets, sliceSizes, strides));
+ loc, getOutput(), resultOffsets, resultSizes, strides));
SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
>From 508eea4c51d97f6fe7c4e954625eb0ad14b86941 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 18 Jul 2024 07:24:42 +0100
Subject: [PATCH 4/4] Update comments
---
.../Dialect/Linalg/Transforms/Transforms.h | 52 ++++++++++++-------
1 file changed, 32 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8ec698b8c0f59..f86eb2d71fad8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1341,8 +1341,8 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
-/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
-/// C. After the rewriting, we get
+/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
+/// the rewriting, we get
///
/// scf.for %f = lo_f to hi_f step 1
/// scf.for %c = lo_c to hi_c step 1
@@ -1356,30 +1356,42 @@ decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
/// Rewrite linalg.winograd_input_transform. The data layout of the input is
/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
-/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
-/// C. After the rewriting, we get
-///
-/// scf.for %n = lo_n to hi_n step 1
-/// scf.for %c = lo_c to hi_c step 1
-/// %extracted = extract input<h x w> from input<n x h x w x c>
-/// %ret = linalg.matmul BT, %extracted
-/// %ret = linalg.matmul %ret, B
-/// %inserted = insert %ret into input<h x w x n x c>
+/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
+/// and tileW. After the rewriting, we get
+///
+/// scf.for %h = 0 to tileH step 1
+/// scf.for %w = 0 to tileW step 1
+/// scf.for %n = 0 to N step 1
+/// scf.for %c = 0 to C step 1
+/// %extracted = extract %extracted<alphaH x alphaW> from
+/// %input<N x H x W x C>
+/// at [%n, (%h x m), (%w x m), %c]
+/// %ret = linalg.matmul BT, %extracted
+/// %ret = linalg.matmul %ret, B
+/// %inserted = insert %ret<alphaH x alphaW> into
+/// %output<alphaH x alphaW x tileH x tileW x N x C>
+/// at [0, 0, %h, %w, %n, %c]
FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
linalg::WinogradInputTransformOp op);
/// Rewrite linalg.winograd_output_transform. The data layout of the output is
/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
-/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
-/// F. After the transformation, we get
-///
-/// scf.for %n = lo_n to hi_n step 1
-/// scf.for %f = lo_f to hi_f step 1
-/// %extracted = extract input<h x w> from result<h x w x n x f>
-/// %ret = linalg.matmul AT, %extracted
-/// %ret = linalg.matmul %ret, A
-/// %inserted = insert %ret into ret<n x h x w x f>
+/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
+/// and tileW. After the transformation, we get
+///
+/// scf.for %h = 0 to tileH step 1
+/// scf.for %w = 0 to tileW step 1
+/// scf.for %n = 0 to N step 1
+/// scf.for %f = 0 to F step 1
+/// %extracted = extract %extracted<alphaH x alphaW> from
+/// %input<alphaH x alphaW x tileH x tileW x N x F>
+/// at [0, 0, %h, %w, %n, %f]
+/// %ret = linalg.matmul AT, %extracted
+/// %ret = linalg.matmul %ret, A
+/// %inserted = insert %ret<alphaH x alphaW> into
+/// %output<N x H x W x F>
+/// at [%n, (%h x m), (%w x m), %f]
FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
linalg::WinogradOutputTransformOp op);
More information about the Mlir-commits
mailing list