[Mlir-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
Hsiangkai Wang
llvmlistbot at llvm.org
Wed Jul 17 22:48:11 PDT 2024
https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/96184
>From 72477955477477f1d95e2125cd33122899154f0d Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:44:27 +0100
Subject: [PATCH 1/2] [mlir][linalg] Implement TilingInterface for winograd
operations
In order to support arbitrary size input data of conv2d, implement
TilingInterface for winograd operations. Before converting winograd
operations into nested loops with matrix multiply, tile the input of
conv2d into the supported size first.
Add a transform operation structured.decompose_winograd_op to decompose
winograd operations. Before applying the transform op, use tile_using_for
to tile the input data into supported size. The test case shows how to
tile and decompose winograd operations.
---
.../mlir/Dialect/Linalg/IR/LinalgOps.td | 22 +-
.../Linalg/TransformOps/LinalgTransformOps.td | 37 +++
.../Dialect/Linalg/Transforms/Transforms.h | 45 +++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 216 +++++++++++++++
.../TransformOps/LinalgTransformOps.cpp | 41 +++
.../Linalg/Transforms/WinogradConv2D.cpp | 18 ++
.../transform-tile-and-winograd-rewrite.mlir | 260 ++++++++++++++++++
7 files changed, 633 insertions(+), 6 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a9007c8db3078..bc2df7082ef8e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,8 +154,8 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}
-def Linalg_WinogradFilterTransformOp :
- Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
+ [AllElementTypesMatch<["filter", "output"]>]> {
let summary = "Winograd filter transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -193,8 +193,13 @@ def Linalg_WinogradFilterTransformOp :
let hasVerifier = 1;
}
-def Linalg_WinogradInputTransformOp :
- Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
+ [AllElementTypesMatch<["input", "output"]>,
+ DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
let summary = "Winograd input transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -232,8 +237,13 @@ def Linalg_WinogradInputTransformOp :
let hasVerifier = 1;
}
-def Linalg_WinogradOutputTransformOp :
- Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
+ [AllElementTypesMatch<["value", "output"]>,
+ DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
let summary = "Winograd output transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index ecc86999006db..106f0d79d9792 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2697,4 +2697,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
}];
}
+def DecomposeWinogradOp : Op<Transform_Dialect,
+ "structured.decompose_winograd_op",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Decompose winograd operations. It will convert filter, input and output
+ transform operations into a combination of scf, tensor, and linalg
+ equivalent operations. Before applying this transform operation, users
+ need to tile winograd transform operations into supported sizes.
+
+ #### Return modes:
+
+ This operation fails if `target` is unsupported. Otherwise, the operation
+ succeeds and returns a handle of the sequence that replaces the original
+ operations.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type($target, results)";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>
+ ];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0c7a8edff222f..8ec698b8c0f59 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1339,6 +1339,51 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcOp op, int64_t m,
int64_t r);
+/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
+/// FHWC. The transformation matrix is 2-dimensional. We need to extract H x W
+/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
+/// C. After the rewriting, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+/// scf.for %c = lo_c to hi_c step 1
+/// %extracted = extract filter<h x w> from filter<f x h x w x c>
+/// %ret = linalg.matmul G, %extracted
+/// %ret = linalg.matmul %ret, GT
+/// %inserted = insert %ret into filter<h x w x c x f>
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+ linalg::WinogradFilterTransformOp op);
+
+/// Rewrite linalg.winograd_input_transform. The data layout of the input is
+/// NHWC. The transformation matrix is 2-dimensional. We need to extract H x W
+/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
+/// C. After the rewriting, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+/// scf.for %c = lo_c to hi_c step 1
+/// %extracted = extract input<h x w> from input<n x h x w x c>
+/// %ret = linalg.matmul BT, %extracted
+/// %ret = linalg.matmul %ret, B
+/// %inserted = insert %ret into input<h x w x n x c>
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+ linalg::WinogradInputTransformOp op);
+
+/// Rewrite linalg.winograd_output_transform. The data layout of the output is
+/// HWNF. The transformation matrix is 2-dimensional. We need to extract H x W
+/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
+/// F. After the transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+/// scf.for %f = lo_f to hi_f step 1
+/// %extracted = extract input<h x w> from result<h x w x n x f>
+/// %ret = linalg.matmul AT, %extracted
+/// %ret = linalg.matmul %ret, A
+/// %inserted = insert %ret into ret<n x h x w x f>
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+ linalg::WinogradOutputTransformOp op);
+
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cefaad9b22653..c93aebb095784 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2776,6 +2776,15 @@ LogicalResult WinogradFilterTransformOp::verify() {
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//
+Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
+ Location loc) {
+ if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
+ auto intAttr = cast<IntegerAttr>(attr);
+ return builder.create<arith::ConstantOp>(loc, intAttr);
+ }
+ return opFoldResult.get<Value>();
+}
+
LogicalResult WinogradInputTransformOp::verify() {
auto inputType = cast<ShapedType>(getInput().getType());
ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2813,6 +2822,113 @@ LogicalResult WinogradInputTransformOp::verify() {
return success();
}
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+ Location loc = getLoc();
+ auto indexType = builder.getIndexType();
+ auto zeroAttr = builder.getIntegerAttr(indexType, 0);
+ auto oneAttr = builder.getIntegerAttr(indexType, 1);
+ Value output = getOutput();
+ SmallVector<Range> loopBounds(6);
+ for (unsigned dim = 0; dim < 6; ++dim) {
+ loopBounds[dim].offset = zeroAttr;
+ loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+ loopBounds[dim].stride = oneAttr;
+ }
+ return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+ SmallVector<utils::IteratorType> iteratorTypes(6,
+ utils::IteratorType::parallel);
+ return iteratorTypes;
+}
+
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) {
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ auto oneAttr = builder.getI64IntegerAttr(1);
+
+ resultOffsets.push_back(zeroAttr);
+ resultOffsets.push_back(zeroAttr);
+ resultOffsets.push_back(offsets[2]);
+ resultOffsets.push_back(offsets[3]);
+ resultOffsets.push_back(zeroAttr);
+ resultOffsets.push_back(zeroAttr);
+ resultSizes.push_back(sizes[0]);
+ resultSizes.push_back(sizes[1]);
+ resultSizes.push_back(oneAttr);
+ resultSizes.push_back(oneAttr);
+ resultSizes.push_back(sizes[4]);
+ resultSizes.push_back(sizes[5]);
+
+ return success();
+}
+
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+ auto oneAttr = builder.getI64IntegerAttr(1);
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ Value input = getInput();
+ auto inputType = cast<ShapedType>(input.getType());
+ auto inputShape = inputType.getShape();
+ int64_t inputH = inputShape[1];
+ int64_t inputW = inputShape[2];
+ int64_t m = getM();
+ int64_t r = getR();
+ int64_t alpha = m + r - 1;
+ int64_t alphaH = inputH != 1 ? alpha : 1;
+ int64_t alphaW = inputW != 1 ? alpha : 1;
+ auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+ auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+ Location loc = getLoc();
+ SmallVector<Value> tiledOperands;
+ SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+ auto context = builder.getContext();
+ auto affineMap =
+ AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+ Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+ loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+ Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+ loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
+
+ sliceOffsets.push_back(zeroAttr);
+ sliceOffsets.push_back(mappedOffset1);
+ sliceOffsets.push_back(mappedOffset2);
+ sliceOffsets.push_back(zeroAttr);
+ sliceSizes.push_back(sizes[4]);
+ sliceSizes.push_back(alphaHAttr);
+ sliceSizes.push_back(alphaWAttr);
+ sliceSizes.push_back(sizes[5]);
+ SmallVector<OpFoldResult> inputStrides(4, oneAttr);
+ tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+ loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+
+ sliceOffsets.clear();
+ sliceSizes.clear();
+ if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
+ sliceSizes)))
+ return failure();
+
+ SmallVector<OpFoldResult> outputStrides(6, oneAttr);
+ tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
+
+ SmallVector<Type, 4> resultTypes;
+ resultTypes.push_back(tiledOperands[1].getType());
+ Operation *tiledOp =
+ mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//
@@ -2855,6 +2971,106 @@ LogicalResult WinogradOutputTransformOp::verify() {
return success();
}
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+ Location loc = getLoc();
+ auto indexType = builder.getIndexType();
+ auto zeroAttr = builder.getIntegerAttr(indexType, 0);
+ auto oneAttr = builder.getIntegerAttr(indexType, 1);
+ Value value = getValue();
+ SmallVector<Range> loopBounds(6);
+ for (unsigned dim = 0; dim < 6; ++dim) {
+ loopBounds[dim].offset = zeroAttr;
+ loopBounds[dim].size = getDimValue(builder, loc, value, dim);
+ loopBounds[dim].stride = oneAttr;
+ }
+ return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+ SmallVector<utils::IteratorType> iteratorTypes(6,
+ utils::IteratorType::parallel);
+ return iteratorTypes;
+}
+
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) {
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ Value output = getOutput();
+ auto outputType = cast<ShapedType>(output.getType());
+ auto outputShape = outputType.getShape();
+ int64_t outputH = outputShape[1];
+ int64_t outputW = outputShape[2];
+ int64_t m = getM();
+ auto heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
+ auto widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
+
+ Location loc = getLoc();
+ auto context = builder.getContext();
+ auto affineMap =
+ AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+ Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+ loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+ Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+ loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
+
+ resultOffsets.push_back(zeroAttr);
+ resultOffsets.push_back(mappedOffset1);
+ resultOffsets.push_back(mappedOffset2);
+ resultOffsets.push_back(zeroAttr);
+ resultSizes.push_back(sizes[4]);
+ resultSizes.push_back(heightM);
+ resultSizes.push_back(widthM);
+ resultSizes.push_back(sizes[5]);
+ return success();
+}
+
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
+ OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+ auto oneAttr = builder.getI64IntegerAttr(1);
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ Location loc = getLoc();
+ SmallVector<Value> tiledOperands;
+ SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+ sliceOffsets.push_back(zeroAttr);
+ sliceOffsets.push_back(zeroAttr);
+ sliceOffsets.push_back(offsets[2]);
+ sliceOffsets.push_back(offsets[3]);
+ sliceOffsets.push_back(zeroAttr);
+ sliceOffsets.push_back(zeroAttr);
+ sliceSizes.push_back(sizes[0]);
+ sliceSizes.push_back(sizes[1]);
+ sliceSizes.push_back(oneAttr);
+ sliceSizes.push_back(oneAttr);
+ sliceSizes.push_back(sizes[4]);
+ sliceSizes.push_back(sizes[5]);
+ SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
+ tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+ loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
+
+ sliceOffsets.clear();
+ sliceSizes.clear();
+ if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
+ sliceSizes)))
+ return failure();
+
+ SmallVector<OpFoldResult> strides(4, oneAttr);
+ tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+ loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+ SmallVector<Type, 4> resultTypes;
+ resultTypes.push_back(tiledOperands[1].getType());
+ Operation *tiledOp =
+ mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b611347b8de2e..e279a6089302e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3748,6 +3748,47 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
+ transform::TransformRewriter &rewriter, Operation *target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ rewriter.setInsertionPoint(target);
+ FailureOr<Operation *> maybeTransformed = failure();
+ bool supported =
+ TypeSwitch<Operation *, bool>(target)
+ .Case([&](linalg::WinogradFilterTransformOp op) {
+ maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
+ return true;
+ })
+ .Case([&](linalg::WinogradInputTransformOp op) {
+ maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
+ return true;
+ })
+ .Case([&](linalg::WinogradOutputTransformOp op) {
+ maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
+ return true;
+ })
+ .Default([&](Operation *op) { return false; });
+
+ if (!supported) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "this operation is not supported to decompose into other operations";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ if (supported && failed(maybeTransformed)) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "decompose Winograd operations failed";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ results.push_back(*maybeTransformed);
+ return DiagnosedSilenceableFailure::success();
+}
+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 754f832e98eea..c9a6c906008a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -1170,6 +1170,24 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
return winogradConv2DHelper(rewriter, op, m, r);
}
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+ linalg::WinogradFilterTransformOp op) {
+ return decomposeWinogradFilterTransformHelper(rewriter, op);
+}
+
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+ linalg::WinogradInputTransformOp op) {
+ return decomposeWinogradInputTransformHelper(rewriter, op);
+}
+
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+ linalg::WinogradOutputTransformOp op) {
+ return decomposeWinogradOutputTransformHelper(rewriter, op);
+}
+
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
int64_t r) {
MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
new file mode 100644
index 0000000000000..29833324bb21b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -0,0 +1,260 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+ %0 = tensor.empty() : tensor<6x6x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %2 = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+ %4 = tensor.empty() : tensor<36x8x2xf32>
+ %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+ %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ return %6 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK: scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK-NEXT: scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG3]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: %[[S5:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK: %[[S6:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK: scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK-NEXT: scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x2x2x5xf32>) {
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: tensor.extract_slice %[[ARG0]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[S15:.*]] = scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK-NEXT: scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG8]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
+// CHECK: %[[S9:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK: %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK: scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-NEXT: scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK: tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
+// CHECK: %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: %[[S16:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+
+// -----
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<6x6x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+ %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+ ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
+ tensor.yield %cst : f32
+ } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+ %2 = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+ %4 = tensor.empty() : tensor<36x18x2xf32>
+ %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+ %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+ %padded_1 = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+ ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
+ tensor.yield %cst : f32
+ } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+ %extracted_slice = tensor.extract_slice %6[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+ return %extracted_slice : tensor<2x9x9x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x5x2xf32>) {
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
+// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
+// CHECK: linalg.matmul
+// CHECK: %[[S13:.*]] = linalg.matmul
+// CHECK: %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x5x2xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK: ^bb0({{.*}}):
+// CHECK: tensor.yield %[[CST_6]] : f32
+// CHECK: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x3x3x2x5xf32>) {
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32>
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x5xf32> to tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK: scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<6x6x1x1x2x5xf32>) {
+// CHECK: linalg.matmul
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[S18:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S17]] into %[[S18]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x1x1x2x5xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_14]] : tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x3x3x2x5xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
+// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK: ^bb0({{.*}}):
+// CHECK: tensor.yield %[[CST_6]] : f32
+// CHECK: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x12x12x2xf32>) {
+// CHECK: tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6x1x1x2x2xf32>
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32>
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK: scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: %[[S22:.*]] = linalg.mul
+// CHECK: %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
+// CHECK: %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
+// CHECK: %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_14]] : tensor<2x4x4x2xf32>
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK: }
+
+// -----
+
+func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+ %0 = tensor.empty() : tensor<6x1x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+ %4 = tensor.empty() : tensor<6x2x2xf32>
+ %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+ %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+ return %6 : tensor<2x4x1x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x5x2xf32>) {
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
+// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<1x3x1x1xf32> to tensor<3x1xf32>
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x5x2xf32>
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x1x1x2x5xf32>) {
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<2x6x1x5xf32> to tensor<1x6x1x1xf32>
+// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> to tensor<6x1xf32>
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> into tensor<6x1x1x1x2x5xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x1x1x2x5xf32>
+// CHECK: %[[S5:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x4x1x2xf32>) {
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x2x2xf32> to tensor<6x1x1x1x1x1xf32>
+// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> to tensor<6x1xf32>
+// CHECK: linalg.matmul
+// CHECK: %[[S12:.*]] = linalg.mul
+// CHECK: %[[S13:.*]] = tensor.empty() : tensor<1x4x1x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[S13]][0, 0, 0, 0] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<4x1xf32> into tensor<1x4x1x1xf32>
+// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<1x4x1x1xf32> into tensor<2x4x1x2xf32>
+// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<2x4x1x2xf32>
>From 936356fd49b7f83aa307eb3fdf46348ee151c149 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 17 Jul 2024 18:26:14 +0100
Subject: [PATCH 2/2] Address ftynse's comments
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 122 +++----
.../transform-tile-and-winograd-rewrite.mlir | 332 ++++++++++--------
2 files changed, 226 insertions(+), 228 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c93aebb095784..c02abf692784c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2776,15 +2776,6 @@ LogicalResult WinogradFilterTransformOp::verify() {
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//
-Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
- Location loc) {
- if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
- auto intAttr = cast<IntegerAttr>(attr);
- return builder.create<arith::ConstantOp>(loc, intAttr);
- }
- return opFoldResult.get<Value>();
-}
-
LogicalResult WinogradInputTransformOp::verify() {
auto inputType = cast<ShapedType>(getInput().getType());
ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2825,9 +2816,9 @@ LogicalResult WinogradInputTransformOp::verify() {
SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
Location loc = getLoc();
- auto indexType = builder.getIndexType();
- auto zeroAttr = builder.getIntegerAttr(indexType, 0);
- auto oneAttr = builder.getIntegerAttr(indexType, 1);
+ IndexType indexType = builder.getIndexType();
+ IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
+ IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
Value output = getOutput();
SmallVector<Range> loopBounds(6);
for (unsigned dim = 0; dim < 6; ++dim) {
@@ -2849,21 +2840,13 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
- auto zeroAttr = builder.getI64IntegerAttr(0);
- auto oneAttr = builder.getI64IntegerAttr(1);
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+ IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
- resultOffsets.push_back(zeroAttr);
- resultOffsets.push_back(zeroAttr);
- resultOffsets.push_back(offsets[2]);
- resultOffsets.push_back(offsets[3]);
- resultOffsets.push_back(zeroAttr);
- resultOffsets.push_back(zeroAttr);
- resultSizes.push_back(sizes[0]);
- resultSizes.push_back(sizes[1]);
- resultSizes.push_back(oneAttr);
- resultSizes.push_back(oneAttr);
- resultSizes.push_back(sizes[4]);
- resultSizes.push_back(sizes[5]);
+ resultOffsets.append(
+ {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+ resultSizes.append(
+ {sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
return success();
}
@@ -2872,11 +2855,11 @@ FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
- auto oneAttr = builder.getI64IntegerAttr(1);
- auto zeroAttr = builder.getI64IntegerAttr(0);
+ IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
Value input = getInput();
auto inputType = cast<ShapedType>(input.getType());
- auto inputShape = inputType.getShape();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputH = inputShape[1];
int64_t inputW = inputShape[2];
int64_t m = getM();
@@ -2884,29 +2867,25 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
int64_t alpha = m + r - 1;
int64_t alphaH = inputH != 1 ? alpha : 1;
int64_t alphaW = inputW != 1 ? alpha : 1;
- auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
- auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+ IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+ IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
Location loc = getLoc();
SmallVector<Value> tiledOperands;
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
- auto context = builder.getContext();
+ MLIRContext *context = builder.getContext();
auto affineMap =
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
- loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+ loc, affineMap,
+ getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
- loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
-
- sliceOffsets.push_back(zeroAttr);
- sliceOffsets.push_back(mappedOffset1);
- sliceOffsets.push_back(mappedOffset2);
- sliceOffsets.push_back(zeroAttr);
- sliceSizes.push_back(sizes[4]);
- sliceSizes.push_back(alphaHAttr);
- sliceSizes.push_back(alphaWAttr);
- sliceSizes.push_back(sizes[5]);
+ loc, affineMap,
+ getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+
+ sliceOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
+ sliceSizes.append({sizes[4], alphaHAttr, alphaWAttr, sizes[5]});
SmallVector<OpFoldResult> inputStrides(4, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
@@ -2921,7 +2900,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
- SmallVector<Type, 4> resultTypes;
+ SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
@@ -2974,9 +2953,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
Location loc = getLoc();
- auto indexType = builder.getIndexType();
- auto zeroAttr = builder.getIntegerAttr(indexType, 0);
- auto oneAttr = builder.getIntegerAttr(indexType, 1);
+ IndexType indexType = builder.getIndexType();
+ IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
+ IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
Value value = getValue();
SmallVector<Range> loopBounds(6);
for (unsigned dim = 0; dim < 6; ++dim) {
@@ -2998,57 +2977,44 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
- auto zeroAttr = builder.getI64IntegerAttr(0);
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
Value output = getOutput();
auto outputType = cast<ShapedType>(output.getType());
- auto outputShape = outputType.getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
int64_t outputH = outputShape[1];
int64_t outputW = outputShape[2];
int64_t m = getM();
- auto heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
- auto widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
+ IntegerAttr heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
+ IntegerAttr widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
Location loc = getLoc();
- auto context = builder.getContext();
+ MLIRContext *context = builder.getContext();
auto affineMap =
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
- loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+ loc, affineMap,
+ getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
- loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
-
- resultOffsets.push_back(zeroAttr);
- resultOffsets.push_back(mappedOffset1);
- resultOffsets.push_back(mappedOffset2);
- resultOffsets.push_back(zeroAttr);
- resultSizes.push_back(sizes[4]);
- resultSizes.push_back(heightM);
- resultSizes.push_back(widthM);
- resultSizes.push_back(sizes[5]);
+ loc, affineMap,
+ getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+
+ resultOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
+ resultSizes.append({sizes[4], heightM, widthM, sizes[5]});
return success();
}
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
- auto oneAttr = builder.getI64IntegerAttr(1);
- auto zeroAttr = builder.getI64IntegerAttr(0);
+ IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
Location loc = getLoc();
SmallVector<Value> tiledOperands;
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
- sliceOffsets.push_back(zeroAttr);
- sliceOffsets.push_back(zeroAttr);
- sliceOffsets.push_back(offsets[2]);
- sliceOffsets.push_back(offsets[3]);
- sliceOffsets.push_back(zeroAttr);
- sliceOffsets.push_back(zeroAttr);
- sliceSizes.push_back(sizes[0]);
- sliceSizes.push_back(sizes[1]);
- sliceSizes.push_back(oneAttr);
- sliceSizes.push_back(oneAttr);
- sliceSizes.push_back(sizes[4]);
- sliceSizes.push_back(sizes[5]);
+ sliceOffsets.append(
+ {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+ sliceSizes.append({sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
@@ -3063,7 +3029,7 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
loc, getOutput(), sliceOffsets, sliceSizes, strides));
- SmallVector<Type, 4> resultTypes;
+ SmallVector<Type> resultTypes;
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 29833324bb21b..6bb3fb1423edc 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -31,52 +31,77 @@ module attributes {transform.with_named_sequence} {
}
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @conv2d
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK: scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG3:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK-NEXT: scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG3]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK: linalg.matmul
-// CHECK: linalg.matmul
-// CHECK: %[[S5:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK: %[[S6:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK: scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<6x6x2x2x2x5xf32>) {
-// CHECK-NEXT: scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x2x2x2x5xf32>) {
-// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: tensor.extract_slice %[[ARG0]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32>
-// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK: %[[S15:.*]] = scf.for {{.*}} = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK-NEXt: scf.for {{.*}} = %[[C0]] to %[[C5]] step %[[C1]] iter_args({{.*}} = %[[ARG8]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK: linalg.matmul
-// CHECK: linalg.matmul
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x2x2x2x5xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x2x2x2x5xf32>
-// CHECK: %[[S9:.*]] = linalg.batch_matmul
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
-// CHECK: %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK: scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) {
-// CHECK-NEXT: scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) {
-// CHECK: tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
-// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32>
-// CHECK: %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK-NEXT: scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK: linalg.matmul
-// CHECK: linalg.matmul
-// CHECK: %[[S16:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK: %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32>
+// CHECK: %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S0:.*]] = tensor.empty()
+// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK: %[[S11:.*]] = linalg.matmul
+// CHECK: %[[S13:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK: %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK: %[[S15:.*]] = linalg.matmul
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S17]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_9]]
+// CHECK: scf.yield %[[S13]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
+// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
+// CHECK: %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
+// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[S19:.*]] = linalg.matmul
+// CHECK: %[[S20:.*]] = tensor.empty()
+// CHECK: %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<4x4xf32>
+// CHECK: %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_9]]
+// CHECK: scf.yield %[[S15]]
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
// -----
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<6x6x5x2xf32>
%1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
@@ -91,7 +116,7 @@ func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5x
%4 = tensor.empty() : tensor<36x18x2xf32>
%5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
%expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
- %padded_1 = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] {
+ %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] {
^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
tensor.yield %cst : f32
} : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
@@ -117,80 +142,82 @@ module attributes {transform.with_named_sequence} {
}
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
-// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x5x2xf32>) {
-// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32>
-// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32>
-// CHECK: linalg.matmul
-// CHECK: %[[S13:.*]] = linalg.matmul
-// CHECK: %[[S14:.*]] = tensor.empty() : tensor<6x6x1x1xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> into tensor<6x6x5x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_11]] : tensor<6x6x5x2xf32>
-// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
-// CHECK: ^bb0({{.*}}):
-// CHECK: tensor.yield %[[CST_6]] : f32
-// CHECK: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
-// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x6x3x3x2x5xf32>) {
-// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK: tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32>
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x5xf32> to tensor<6x6x1x1x2x5xf32>
-// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK: scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<6x6x1x1x2x5xf32>) {
-// CHECK: linalg.matmul
-// CHECK: %[[S17:.*]] = linalg.matmul
-// CHECK: %[[S18:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S17]] into %[[S18]][0, 0, 0, 0, 0, 0] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> into tensor<6x6x1x1x2x5xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_14]] : tensor<6x6x1x1x2x5xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x5xf32> into tensor<6x6x3x3x2x5xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32>
-// CHECK: %[[S6:.*]] = linalg.batch_matmul
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
-// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
-// CHECK: ^bb0({{.*}}):
-// CHECK: tensor.yield %[[CST_6]] : f32
-// CHECK: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
-// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x12x12x2xf32>) {
-// CHECK: tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6x1x1x2x2xf32>
-// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32>
-// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK: scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) -> (tensor<2x4x4x2xf32>) {
-// CHECK: linalg.matmul
-// CHECK: linalg.matmul
-// CHECK: %[[S22:.*]] = linalg.mul
-// CHECK: %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
-// CHECK: %[[INSERTED_SLICE_13:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
-// CHECK: %[[INSERTED_SLICE_14:.*]] = tensor.insert_slice %[[INSERTED_SLICE_13]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_14]] : tensor<2x4x4x2xf32>
-// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
-// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK: }
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S0:.*]] = tensor.empty()
+// CHECK: %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1]
+// CHECK: %[[S11:.*]] = linalg.matmul
+// CHECK: %[[S13:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32>
+// CHECK: scf.yield %[[S9]] : tensor<6x6x5x2xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK: %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK: %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
+// CHECK: %[[S15:.*]] = linalg.matmul
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S17]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32>
+// CHECK: scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S6:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
+// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
+// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
+// CHECK: %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
+// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S17:.*]] = linalg.matmul
+// CHECK: %[[S19:.*]] = linalg.matmul
+// CHECK: %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK: %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<4x4xf32>
+// CHECK: %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S22]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_12]]
+// CHECK: scf.yield %[[S15]] : tensor<2x4x4x2xf32>
+// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
+// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S9]]
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1]
+// CHECK: return %[[EXTRACTED_SLICE]]
// -----
-func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
%0 = tensor.empty() : tensor<6x1x5x2xf32>
%1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
%2 = tensor.empty() : tensor<6x1x1x1x2x5xf32>
@@ -200,7 +227,7 @@ func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>
%4 = tensor.empty() : tensor<6x2x2xf32>
%5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
%expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
- %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+ %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
return %6 : tensor<2x4x1x2xf32>
}
@@ -220,41 +247,46 @@ module attributes {transform.with_named_sequence} {
}
}
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @conv2d_mx1_rx1
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C5:.*]] = arith.constant 5 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
-// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) -> (tensor<6x1x5x2xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x5x2xf32>) {
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32>
-// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<1x3x1x1xf32> to tensor<3x1xf32>
-// CHECK: %[[S9:.*]] = linalg.matmul
-// CHECK: %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x1x1xf32> into tensor<6x1x5x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x5x2xf32>
-// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
-// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]]) -> (tensor<6x1x1x1x2x5xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<6x1x1x1x2x5xf32>) {
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<2x6x1x5xf32> to tensor<1x6x1x1xf32>
-// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 1, 1] [1, 1, 1, 1] : tensor<1x6x1x1xf32> to tensor<6x1xf32>
-// CHECK: %[[S9:.*]] = linalg.matmul
-// CHECK: %[[S10:.*]] = tensor.empty() : tensor<6x1x1x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[S10]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1xf32> into tensor<6x1x1x1x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> into tensor<6x1x1x1x2x5xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<6x1x1x1x2x5xf32>
-// CHECK: %[[S5:.*]] = linalg.batch_matmul
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
-// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG3]]) -> (tensor<2x4x1x2xf32>) {
-// CHECK: scf.for %[[ARG6:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) -> (tensor<2x4x1x2xf32>) {
-// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG4]], %[[ARG6]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x2x2xf32> to tensor<6x1x1x1x1x1xf32>
-// CHECK: tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x1x1x1x1xf32> to tensor<6x1xf32>
-// CHECK: linalg.matmul
-// CHECK: %[[S12:.*]] = linalg.mul
-// CHECK: %[[S13:.*]] = tensor.empty() : tensor<1x4x1x1xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[S13]][0, 0, 0, 0] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<4x1xf32> into tensor<1x4x1x1xf32>
-// CHECK: %[[INSERTED_SLICE_5:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG7]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 4, 1, 1] [1, 1, 1, 1] : tensor<1x4x1x1xf32> into tensor<2x4x1x2xf32>
-// CHECK: scf.yield %[[INSERTED_SLICE_5]] : tensor<2x4x1x2xf32>
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S5:.*]] = linalg.batch_matmul
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2]
+// CHECK: %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = linalg.matmul
+// CHECK: %[[S10:.*]] = tensor.empty() : tensor<4x1xf32>
+// CHECK: %[[S11:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S10]] : tensor<4x1xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: } -> tensor<4x1xf32>
+// CHECK: %[[S12:.*]] = linalg.mul ins(%[[S11]], %[[S9]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// CHECK: return %[[S6]]
More information about the Mlir-commits
mailing list