[llvm-branch-commits] [mlir] [mlir][linalg] Add transform operator for Winograd Conv2D algorithm (PR #96182)
Hsiangkai Wang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jun 26 07:08:31 PDT 2024
https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/96182
>From 374b0d5b83ce080bea690199380e270a36ad1c52 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:49:08 +0100
Subject: [PATCH 1/6] [mlir][linalg] Add transform operator for Winograd Conv2D
algorithm
Add a transform operator, structured.winograd_conv2d, to convert
linalg.conv_2d_nhwc_fhwc into the Linalg Winograd operators.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 51 +++++++++++
.../Dialect/Linalg/Transforms/Transforms.h | 7 ++
.../TransformOps/LinalgTransformOps.cpp | 25 ++++++
.../Linalg/Transforms/WinogradConv2D.cpp | 6 ++
.../Linalg/transform-winograd-conv2d.mlir | 88 +++++++++++++++++++
5 files changed, 177 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..68d0f713caad4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
}];
}
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+ "structured.winograd_conv2d",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+ matrix multiply. Before the matrix multiply, it will convert filter and
+ input into a format suitable for batched matrix multiply. After the matrix
+ multiply, it will convert output to the final result tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ #### Return modes:
+
+ This operation fails if `target` is unsupported. Otherwise, the operation
+ succeeds and returns a handle of the sequence that replaces the original
+ convolution.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ I64Attr:$m,
+ I64Attr:$r);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type($target, results)";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>
+ ];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::linalg::LinalgOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 835aeaf2ffed3..da107b66257a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp op,
bool transposeLHS = true);
+/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
+/// F(m x m, r x r). m is the dimension size of output and r is the dimension
+/// size of filter.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+ linalg::Conv2DNhwcFhwcOp op, int64_t m,
+ int64_t r);
+
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..d051b29e1f06f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+ transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ rewriter.setInsertionPoint(target);
+ auto maybeTransformed =
+ TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+ .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+ return winogradConv2D(rewriter, op, getM(), getR());
+ })
+ .Default([&](Operation *op) {
+ return rewriter.notifyMatchFailure(op, "not supported");
+ });
+
+ if (failed(maybeTransformed))
+ return emitDefaultSilenceableFailure(target);
+
+ results.push_back(*maybeTransformed);
+ return DiagnosedSilenceableFailure::success();
+}
+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 86e834d51f2fc..d1f4be8bbf29a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -311,6 +311,12 @@ class WinogradConv2DNhwcFhwc final
} // end anonymous namespace
//===----------------------------------------------------------------------===//
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+ linalg::Conv2DNhwcFhwcOp op, int64_t m,
+ int64_t r) {
+ return winogradConv2DHelper(rewriter, op, m, r);
+}
+
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
int64_t r) {
MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
new file mode 100644
index 0000000000000..1e74fea5a1c31
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+ %0 = tensor.empty() : tensor<2x8x8x2xf32>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<2x8x8x2xf32>
+ %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ return %2 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+ %0 = tensor.empty() : tensor<2x9x9x2xf32>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<2x9x9x2xf32>
+ %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+ return %2 : tensor<2x9x9x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x9x9x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
+// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
+// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }
>From c94b1a3d2b30eefaa556b8ddf1f4767d89d72fe0 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 26 Jun 2024 09:43:43 +0100
Subject: [PATCH 2/6] Revert "[mlir][linalg] Add transform operator for
Winograd Conv2D algorithm"
This reverts commit 374b0d5b83ce080bea690199380e270a36ad1c52.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 51 -----------
.../Dialect/Linalg/Transforms/Transforms.h | 7 --
.../TransformOps/LinalgTransformOps.cpp | 25 ------
.../Linalg/Transforms/WinogradConv2D.cpp | 6 --
.../Linalg/transform-winograd-conv2d.mlir | 88 -------------------
5 files changed, 177 deletions(-)
delete mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 68d0f713caad4..93e2c2db729da 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,55 +2587,4 @@ def MapCopyToThreadsOp :
}];
}
-//===----------------------------------------------------------------------===//
-// Winograd Conv2D
-//===----------------------------------------------------------------------===//
-
-def WinogradConv2DOp : Op<Transform_Dialect,
- "structured.winograd_conv2d",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
- let description = [{
- Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
- matrix multiply. Before the matrix multiply, it will convert filter and
- input into a format suitable for batched matrix multiply. After the matrix
- multiply, it will convert output to the final result tensor.
-
- The algorithm F(m x m, r x r) is
-
- Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
- The size of output Y is m x m. The size of filter g is r x r. The size of
- input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
- transformation matrices.
-
- #### Return modes:
-
- This operation fails if `target` is unsupported. Otherwise, the operation
- succeeds and returns a handle of the sequence that replaces the original
- convolution.
- }];
-
- let arguments = (ins TransformHandleTypeInterface:$target,
- I64Attr:$m,
- I64Attr:$r);
- let results = (outs TransformHandleTypeInterface:$transformed);
-
- let assemblyFormat =
- "$target attr-dict `:` functional-type($target, results)";
-
- let builders = [
- OpBuilder<(ins "Value":$target)>
- ];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
-}
-
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index da107b66257a5..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,13 +1312,6 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp op,
bool transposeLHS = true);
-/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
-/// F(m x m, r x r). m is the dimension size of output and r is the dimension
-/// size of filter.
-FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
- linalg::Conv2DNhwcFhwcOp op, int64_t m,
- int64_t r);
-
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d051b29e1f06f..bc02788f9c441 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,31 +3480,6 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
-//===----------------------------------------------------------------------===//
-// WinogradConv2DOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
- transform::TransformRewriter &rewriter, linalg::LinalgOp target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- rewriter.setInsertionPoint(target);
- auto maybeTransformed =
- TypeSwitch<Operation *, FailureOr<Operation *>>(target)
- .Case([&](linalg::Conv2DNhwcFhwcOp op) {
- return winogradConv2D(rewriter, op, getM(), getR());
- })
- .Default([&](Operation *op) {
- return rewriter.notifyMatchFailure(op, "not supported");
- });
-
- if (failed(maybeTransformed))
- return emitDefaultSilenceableFailure(target);
-
- results.push_back(*maybeTransformed);
- return DiagnosedSilenceableFailure::success();
-}
-
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index d1f4be8bbf29a..86e834d51f2fc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -311,12 +311,6 @@ class WinogradConv2DNhwcFhwc final
} // end anonymous namespace
//===----------------------------------------------------------------------===//
-FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
- linalg::Conv2DNhwcFhwcOp op, int64_t m,
- int64_t r) {
- return winogradConv2DHelper(rewriter, op, m, r);
-}
-
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
int64_t r) {
MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
deleted file mode 100644
index 1e74fea5a1c31..0000000000000
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ /dev/null
@@ -1,88 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
-
-func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
- %0 = tensor.empty() : tensor<2x8x8x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x8x8x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
- return %2 : tensor<2x8x8x2xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
- transform.yield
- }
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT: linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x8x8x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
- %0 = tensor.empty() : tensor<2x9x9x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x9x9x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
- return %2 : tensor<2x9x9x2xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
- transform.yield
- }
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT: linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x9x9x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
-// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
-// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK-NEXT: }
>From 5a391881394094bfd747cb97bf023ed3df06923e Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 26 Jun 2024 09:44:19 +0100
Subject: [PATCH 3/6] Revert "[mlir][linalg] Implement Conv2D using Winograd
Conv2D algorithm"
This reverts commit 4240341b4f06f1b77f63b0f619cae3804d88eb68.
---
.../mlir/Dialect/Linalg/IR/LinalgOps.td | 114 -------
.../Dialect/Linalg/Transforms/Transforms.h | 4 -
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 78 -----
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 -
.../Linalg/Transforms/WinogradConv2D.cpp | 321 ------------------
mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 248 --------------
.../Dialect/Linalg/TestLinalgTransforms.cpp | 13 -
7 files changed, 779 deletions(-)
delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
delete mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index de1097b6ac27b..64c538367267d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,118 +154,4 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}
-def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
- let summary = "Winograd filter transform operator";
- let description = [{
- Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
- matrix multiply. Before the matrix multiply, it will convert filter and
- input into a format suitable for batched matrix multiply. After the matrix
- multiply, it will convert output to the final result tensor.
-
- The algorithm F(m x m, r x r) is
-
- Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
- The size of output Y is m x m. The size of filter g is r x r. The size of
- input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
- transformation matrices.
-
- This operator is defined to represent the high level concept of filter
- transformation (G x g x G^T) in the Winograd Conv2D algorithm.
- }];
-
- let arguments = (ins AnyRankedTensor:$filter,
- AnyRankedTensor:$output,
- I64Attr:$m,
- I64Attr:$r
- );
-
- let results = (outs AnyRankedTensor:$result);
- let assemblyFormat = [{
- attr-dict
- `m` `(` $m `)`
- `r` `(` $r `)`
- `ins` `(` $filter `:` type($filter) `)`
- `outs` `(` $output `:` type($output) `)`
- `->` type($result)
- }];
- let hasVerifier = 1;
-}
-
-def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
- let summary = "Winograd input transform operator";
- let description = [{
- Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
- matrix multiply. Before the matrix multiply, it will convert filter and
- input into a format suitable for batched matrix multiply. After the matrix
- multiply, it will convert output to the final result tensor.
-
- The algorithm F(m x m, r x r) is
-
- Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
- The size of output Y is m x m. The size of filter g is r x r. The size of
- input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
- transformation matrices.
-
- This operator is defined to represent the high level concept of input
- transformation (B^T x d x B) in the Winograd Conv2D algorithm.
- }];
-
- let arguments = (ins AnyRankedTensor:$input,
- AnyRankedTensor:$output,
- I64Attr:$m,
- I64Attr:$r
- );
-
- let results = (outs AnyRankedTensor:$result);
- let assemblyFormat = [{
- attr-dict
- `m` `(` $m `)`
- `r` `(` $r `)`
- `ins` `(` $input `:` type($input) `)`
- `outs` `(` $output `:` type($output) `)`
- `->` type($result)
- }];
- let hasVerifier = 1;
-}
-
-def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
- let summary = "Winograd output transform operator";
- let description = [{
- Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
- matrix multiply. Before the matrix multiply, it will convert filter and
- input into a format suitable for batched matrix multiply. After the matrix
- multiply, it will convert output to the final result tensor.
-
- The algorithm F(m x m, r x r) is
-
- Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
-
- The size of output Y is m x m. The size of filter g is r x r. The size of
- input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
- transformation matrices.
-
- This operator is defined to represent the high level concept of output
- transformation (A^T x y x A) in the Winograd Conv2D algorithm.
- }];
-
- let arguments = (ins AnyRankedTensor:$value,
- AnyRankedTensor:$output,
- I64Attr:$m,
- I64Attr:$r
- );
-
- let results = (outs AnyRankedTensor:$result);
- let assemblyFormat = [{
- attr-dict
- `m` `(` $m `)`
- `r` `(` $r `)`
- `ins` `(` $value `:` type($value) `)`
- `outs` `(` $output `:` type($output) `)`
- `->` type($result)
- }];
- let hasVerifier = 1;
-}
-
#endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 835aeaf2ffed3..05e97befdec1f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,10 +1692,6 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
-/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
- int64_t r);
-
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7bf2a5bca037f..57d126603ebd7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,84 +2734,6 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
return SmallVector<Value>{result};
}
-//===----------------------------------------------------------------------===//
-// WinogradFilterTransformOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult WinogradFilterTransformOp::verify() {
- auto filterType = cast<ShapedType>(getFilter().getType());
- auto outputType = cast<ShapedType>(getOutput().getType());
- auto filterElemType = filterType.getElementType();
- auto outputElemType = outputType.getElementType();
- if (filterElemType != outputElemType) {
- return emitOpError() << "expected element type of input " << filterElemType
- << " to match element type of output "
- << outputElemType;
- }
-
- unsigned filterRank = filterType.getRank();
- if (filterRank != 4)
- return emitOpError() << "expected rank of input is 4";
-
- unsigned outputRank = outputType.getRank();
- if (outputRank != 6)
- return emitOpError() << "expected rank of output is 6";
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// WinogradInputTransformOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult WinogradInputTransformOp::verify() {
- auto inputType = cast<ShapedType>(getInput().getType());
- auto outputType = cast<ShapedType>(getOutput().getType());
- auto inputElemType = inputType.getElementType();
- auto outputElemType = outputType.getElementType();
- if (inputElemType != outputElemType) {
- return emitOpError() << "expected element type of input " << inputElemType
- << " to match element type of output "
- << outputElemType;
- }
-
- unsigned inputRank = inputType.getRank();
- if (inputRank != 4)
- return emitOpError() << "expected rank of input is 4";
-
- unsigned outputRank = outputType.getRank();
- if (outputRank != 6)
- return emitOpError() << "expected rank of output is 6";
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// WinogradOutputTransformOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult WinogradOutputTransformOp::verify() {
- auto valueType = cast<ShapedType>(getValue().getType());
- auto outputType = cast<ShapedType>(getOutput().getType());
- auto valueElemType = valueType.getElementType();
- auto outputElemType = outputType.getElementType();
- if (valueElemType != outputElemType) {
- return emitOpError() << "expected element type of value " << valueElemType
- << " to match element type of output "
- << outputElemType;
- }
-
- unsigned valueRank = valueType.getRank();
- if (valueRank != 6)
- return emitOpError() << "expected rank of input is 6";
-
- unsigned outputRank = outputType.getRank();
- if (outputRank != 4)
- return emitOpError() << "expected rank of output is 4";
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a7dcc29b5b9be..7e3dc56e0acdc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,7 +38,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Transforms.cpp
TransposeConv2D.cpp
Vectorization.cpp
- WinogradConv2D.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
deleted file mode 100644
index 86e834d51f2fc..0000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ /dev/null
@@ -1,321 +0,0 @@
-//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Implement Winograd Conv2D algorithm. The implementation is based on the
-// paper: Fast Algorithms for Convolutional Neural Networks
-// (https://arxiv.org/abs/1509.09308)
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/MathExtras.h"
-
-namespace mlir {
-namespace linalg {
-
-namespace {
-
-using TransformMapKeyTy = std::pair<int, int>;
-
-// We use F(m, r) to define the size of minimal filtering algorithms.
-// m is the output dimension and r is the filter dimension. We can get
-// the input dimension, alpha, from the formula, alpha = m + r - 1.
-//
-// For example, when m = 2 and r = 3, we know its input size is 4.
-// The Conv2D will operate on 4x4 input data with 3x3 filter and get
-// 2x2 output result.
-constexpr TransformMapKeyTy F_2_3{2, 3};
-constexpr TransformMapKeyTy F_4_3{4, 3};
-constexpr TransformMapKeyTy F_2_5{2, 5};
-
-Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
- auto type = cast<ShapedType>(data.getType());
- auto elementType = type.getElementType();
- auto shape = type.getShape();
- auto collapseType = RankedTensorType::get(
- {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
- elementType);
- SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
- return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
- reassociation);
-}
-
-// This function generates linalg.batch_matmul to multiply input with filter.
-// linalg.batch_matmul only supports 3-dimension data sets. We can treat
-// tileH x tileW x H x W data as the 1-dimension data array. That is to convert
-// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this way, we
-// can convert 6-dimension input data to 3-dimension representation that is
-// suitable for linalg.batch_matmul.
-//
-// Batched matmul will do the matrix multiply with the reduction on channel.
-//
-// We get
-//
-// %collapsed_input = tensor.collapse_shape %input
-// %collapsed_filter = tensor.collapse_shape %filter
-// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
-// %expanded_ret = tensor.expand_shape %ret
-//
-// After this function, we get return value with data layout
-// (tileH, tileW, H, W, N, F).
-Value matrixMultiply(RewriterBase &rewriter, Location loc,
- Value transformedFilter, Value transformedInput) {
- auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
- auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
-
- // Batched matrix multiply
- auto filterType = cast<ShapedType>(transformedFilter.getType());
- auto filterShape = filterType.getShape();
- auto inputType = cast<ShapedType>(transformedInput.getType());
- auto inputElemType = inputType.getElementType();
- auto inputShape = inputType.getShape();
-
- auto matmulType = RankedTensorType::get(
- {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
- inputShape[4], filterShape[5]},
- inputElemType);
- Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
- inputElemType);
-
- auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
- loc, matmulType, ValueRange({collapseInput, collapseFilter}),
- ValueRange{init});
-
- // Expand matmul result
- SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
- auto expandType =
- RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
- inputShape[3], inputShape[4], filterShape[5]},
- inputElemType);
- auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
- loc, expandType, matmulOp.getResult(0), reassociation);
- return expandOutput;
-}
-
-Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
- RankedTensorType alignedType) {
- Value alignedInput = rewriter.create<tensor::EmptyOp>(
- loc, alignedType.getShape(), alignedType.getElementType());
-
- auto zeroIndex = rewriter.getIndexAttr(0);
- auto oneIndex = rewriter.getIndexAttr(1);
- SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
- SmallVector<OpFoldResult, 4> strides(4, oneIndex);
-
- auto valueType = cast<ShapedType>(value.getType());
- auto valueShape = valueType.getShape();
- SmallVector<OpFoldResult, 4> sizes;
- sizes.emplace_back(rewriter.getIndexAttr(valueShape[0]));
- sizes.emplace_back(rewriter.getIndexAttr(valueShape[1]));
- sizes.emplace_back(rewriter.getIndexAttr(valueShape[2]));
- sizes.emplace_back(rewriter.getIndexAttr(valueShape[3]));
-
- return rewriter.create<tensor::InsertSliceOp>(loc, value, alignedInput,
- offsets, sizes, strides);
-}
-
-Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
- Value value, RankedTensorType extractedType) {
- auto zeroIndex = rewriter.getIndexAttr(0);
- auto oneIndex = rewriter.getIndexAttr(1);
- SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
- SmallVector<OpFoldResult, 4> strides(4, oneIndex);
-
- auto extractedShape = extractedType.getShape();
- SmallVector<OpFoldResult, 4> sizes;
- sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0]));
- sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1]));
- sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2]));
- sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3]));
-
- return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
- offsets, sizes, strides);
-}
-
-bool hasAllOneValues(DenseIntElementsAttr attr) {
- return llvm::all_of(
- attr, [](const APInt &element) { return element.getSExtValue() == 1; });
-}
-
-FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
- linalg::Conv2DNhwcFhwcOp convOp,
- int64_t m, int64_t r) {
- Value input = convOp.getInputs()[0];
- Value filter = convOp.getInputs()[1];
- Value output = convOp.getOutputs()[0];
- auto inputType = cast<ShapedType>(input.getType());
- auto filterType = cast<ShapedType>(filter.getType());
- auto outputType = cast<ShapedType>(output.getType());
-
- if (!inputType.hasStaticShape())
- return rewriter.notifyMatchFailure(convOp,
- "expected a static shape for the input");
-
- if (!filterType.hasStaticShape())
- return rewriter.notifyMatchFailure(
- convOp, "expected a static shape for the filter");
-
- if (!hasAllOneValues(convOp.getDilations()))
- return rewriter.notifyMatchFailure(convOp,
- "expected all ones for dilations");
-
- if (!hasAllOneValues(convOp.getStrides()))
- return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
-
- auto filterShape = filterType.getShape();
- int64_t filterF = filterShape[0];
- int64_t filterH = filterShape[1];
- int64_t filterW = filterShape[2];
- int64_t filterC = filterShape[3];
- auto inputShape = inputType.getShape();
- int64_t inputN = inputShape[0];
- int64_t inputH = inputShape[1];
- int64_t inputW = inputShape[2];
- int64_t inputC = inputShape[3];
- auto outputShape = outputType.getShape();
- int64_t outputN = outputShape[0];
- int64_t outputH = outputShape[1];
- int64_t outputW = outputShape[2];
- int64_t outputF = outputShape[3];
-
- // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
- bool isSupportedFilter = false;
- if (filterH == filterW && filterH == r)
- isSupportedFilter = true;
- if (filterH == r && filterW == 1)
- isSupportedFilter = true;
- if (filterH == 1 && filterW == r)
- isSupportedFilter = true;
-
- if (!isSupportedFilter)
- return rewriter.notifyMatchFailure(
- convOp, "only support filter (r x r), (r x 1) or (1 x r)");
-
- // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5)
- static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
- F_2_3, F_4_3, F_2_5};
-
- TransformMapKeyTy key = {m, r};
- auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
- // If we cannot find the constant transformation matrix, it means we do
- // not support this configuration yet.
- if (it == validConfigs.end())
- return failure();
-
- // All the criterias are satisfied. We can do Winograd Conv2D.
- Location loc = convOp.getLoc();
-
- // For F(m x 1, r x 1), we only need to do left side transform.
- bool leftTransform = filterH != 1;
- // For F(1 x m, 1 x r), we only need to do right side transform.
- bool rightTransform = filterW != 1;
- int64_t heightM = leftTransform ? m : 1;
- int64_t widthM = rightTransform ? m : 1;
- int64_t heightR = leftTransform ? r : 1;
- int64_t widthR = rightTransform ? r : 1;
-
- // --- Create operator for filter transform ---
- Type elementType = filterType.getElementType();
- int64_t alphaH = heightM + heightR - 1;
- int64_t alphaW = widthM + widthR - 1;
- int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
- int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
- auto retType = RankedTensorType::get(
- {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType);
- Value retValue =
- rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
- auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
- loc, retType, filter, retValue, m, r);
-
- // --- Create operator for input transform ---
-
- // When input size - (r - 1) is not aligned with output tile size, we need to
- // pad the input data to create the full tiles as tiling.
- int64_t alignedInputH = tileH * heightM + (heightR - 1);
- int64_t alignedInputW = tileW * widthM + (widthR - 1);
- if (alignedInputH != inputH || alignedInputW != inputW) {
- auto alignedInputType = RankedTensorType::get(
- {inputN, alignedInputH, alignedInputW, inputC}, elementType);
- input = insertToAlignedTensor(rewriter, loc, input, alignedInputType);
- }
-
- retType = RankedTensorType::get(
- {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType);
- retValue =
- rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
- auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
- loc, retType, input, retValue, m, r);
-
- Value matmulRet =
- matrixMultiply(rewriter, loc, transformedFilter, transformedInput);
-
- // --- Create operator for output transform ---
-
- // When output size is not aligned with output tile size, we need to pad the
- // output buffer to insert the full tiles after tiling.
- int64_t alignedOutputH = tileH * heightM;
- int64_t alignedOutputW = tileW * widthM;
- bool isOutputUnaligned =
- ((alignedOutputH != outputH) || (alignedOutputW != outputW));
- if (isOutputUnaligned) {
- auto alignedOutputType = RankedTensorType::get(
- {outputN, alignedOutputH, alignedOutputW, outputF}, elementType);
- output = insertToAlignedTensor(rewriter, loc, output, alignedOutputType);
- outputType = alignedOutputType;
- }
-
- Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
- loc, outputType, matmulRet, output, m, r);
-
- // When output size is not aligned with output tile size, extract the
- // value from the padded buffer.
- if (isOutputUnaligned) {
- transformedOutput = extractFromAlignedTensor(
- rewriter, loc, transformedOutput,
- RankedTensorType::get({outputN, outputH, outputW, outputF},
- elementType));
- }
-
- rewriter.replaceOp(convOp, transformedOutput);
-
- return transformedOutput.getDefiningOp();
-}
-
-class WinogradConv2DNhwcFhwc final
- : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
- WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
- : OpRewritePattern(context), m(m), r(r) {}
-
- LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
- PatternRewriter &rewriter) const override {
- if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
- return failure();
-
- return success();
- }
-
-private:
- int64_t m;
- int64_t r;
-};
-} // end anonymous namespace
-
-//===----------------------------------------------------------------------===//
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
- int64_t r) {
- MLIRContext *context = patterns.getContext();
- patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
-}
-
-} // end namespace linalg
-} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
deleted file mode 100644
index 6cca3c602d4c0..0000000000000
--- a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
+++ /dev/null
@@ -1,248 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
-
-func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
- %0 = tensor.empty() : tensor<2x4x4x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x4x4x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
- return %2 : tensor<2x4x4x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_4x4_3x3
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT: linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x4x4x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x4x4x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
- %0 = tensor.empty() : tensor<2x2x2x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x2x2x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
- return %2 : tensor<2x2x2x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_2x2_5x5
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT: linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x2x2x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x2x2x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
- %0 = tensor.empty() : tensor<2x1x4x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x1x4x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
- return %2 : tensor<2x1x4x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_1x4_1x3
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT: linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x1x4x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x1x1x6x2x5xf32>) -> tensor<1x1x1x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x1x4x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
- %0 = tensor.empty() : tensor<2x4x1x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x4x1x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
- return %2 : tensor<2x4x1x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_4x1_3x1
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT: linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x4x1x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<1x1x6x1x2x5xf32>) -> tensor<1x1x6x1x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x5x2xf32> into tensor<6x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x2x5xf32> into tensor<6x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x4x1x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
- %0 = tensor.empty() : tensor<2x8x8x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x8x8x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
- return %2 : tensor<2x8x8x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_aligned
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT: linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x8x8x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
- %0 = tensor.empty() : tensor<2x9x9x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x9x9x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
- return %2 : tensor<2x9x9x2xf32>
-}
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT: linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x9x9x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
-// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
-// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK-NEXT: }
-
-// -----
-
-func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
- %0 = tensor.empty() : tensor<2x4x4x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x4x4x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
- return %2 : tensor<2x4x4x2xf32>
-}
-
-// CHECK-LABEL: conv2d_unsupported_1
-// CHECK: linalg.conv_2d_nhwc_fhwc
-
-// -----
-
-func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> {
- %0 = tensor.empty() : tensor<2x4x4x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x4x4x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
- return %2 : tensor<2x4x4x2xf32>
-}
-
-// CHECK-LABEL: conv2d_unsupported_2
-// CHECK: linalg.conv_2d_nhwc_fhwc
-
-// -----
-
-func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
- %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
- return %0 : tensor<?x?x?x?xf32>
-}
-
-// CHECK-LABEL: conv2d_unsupported_3
-// CHECK: linalg.conv_2d_nhwc_fhwc
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 12cb46a5968f1..4892fa2f99a7c 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,10 +123,6 @@ struct TestLinalgTransforms
*this, "test-erase-unnecessary-inputs",
llvm::cl::desc("Test patterns to erase unnecessary inputs"),
llvm::cl::init(false)};
- Option<bool> testWinogradConv2D{
- *this, "test-winograd-conv2d",
- llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
- llvm::cl::init(false)};
};
} // namespace
@@ -211,13 +207,6 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
-static void applyWinogradConv2D(func::FuncOp funcOp) {
- RewritePatternSet patterns(funcOp.getContext());
- populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
- populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
if (testPatterns)
@@ -242,8 +231,6 @@ void TestLinalgTransforms::runOnOperation() {
return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
if (testEraseUnnecessaryInputs)
return applyEraseUnnecessaryInputs(getOperation());
- if (testWinogradConv2D)
- return applyWinogradConv2D(getOperation());
}
namespace mlir {
>From 690662771c806a2f7301bdc4dedc983047c41d35 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:24:07 +0100
Subject: [PATCH 4/6] [mlir][linalg] Implement Conv2D using Winograd Conv2D
algorithm
Define high level winograd operators and convert conv_2d_nhwc_fhwc into
winograd operators. According to Winograd Conv2D algorithm, we need
three transform operators for input, filter, and output transformation.
The formula of Winograd Conv2D algorithm is
Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A
The implementation is based on the paper, Fast Algorithms for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)
---
.../mlir/Dialect/Linalg/IR/LinalgOps.td | 117 ++++++
.../Dialect/Linalg/Transforms/Transforms.h | 4 +
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 107 ++++++
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Linalg/Transforms/WinogradConv2D.cpp | 334 ++++++++++++++++++
mlir/test/Dialect/Linalg/winograd-conv2d.mlir | 193 ++++++++++
.../Dialect/Linalg/TestLinalgTransforms.cpp | 13 +
7 files changed, 769 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
create mode 100644 mlir/test/Dialect/Linalg/winograd-conv2d.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..a9007c8db3078 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,121 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}
+def Linalg_WinogradFilterTransformOp :
+ Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
+ let summary = "Winograd filter transform operator";
+ let description = [{
+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+ matrix multiply. Before the matrix multiply, it will convert filter and
+ input into a format suitable for batched matrix multiply. After the matrix
+ multiply, it will convert output to the final result tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ This operator is defined to represent the high level concept of filter
+ transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+ }];
+
+ let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
+ TensorRankOf<[AnyType], [4]>:$output,
+ I64Attr:$m,
+ I64Attr:$r
+ );
+
+ let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+ let assemblyFormat = [{
+ attr-dict
+ `m` `(` $m `)`
+ `r` `(` $r `)`
+ `ins` `(` $filter `:` type($filter) `)`
+ `outs` `(` $output `:` type($output) `)`
+ `->` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp :
+ Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
+ let summary = "Winograd input transform operator";
+ let description = [{
+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+ matrix multiply. Before the matrix multiply, it will convert filter and
+ input into a format suitable for batched matrix multiply. After the matrix
+ multiply, it will convert output to the final result tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ This operator is defined to represent the high level concept of input
+ transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+ }];
+
+ let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
+ TensorRankOf<[AnyType], [6]>:$output,
+ I64Attr:$m,
+ I64Attr:$r
+ );
+
+ let results = (outs TensorRankOf<[AnyType], [6]>:$result);
+ let assemblyFormat = [{
+ attr-dict
+ `m` `(` $m `)`
+ `r` `(` $r `)`
+ `ins` `(` $input `:` type($input) `)`
+ `outs` `(` $output `:` type($output) `)`
+ `->` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp :
+ Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
+ let summary = "Winograd output transform operator";
+ let description = [{
+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+ matrix multiply. Before the matrix multiply, it will convert filter and
+ input into a format suitable for batched matrix multiply. After the matrix
+ multiply, it will convert output to the final result tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ This operator is defined to represent the high level concept of output
+ transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+ }];
+
+ let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
+ TensorRankOf<[AnyType], [4]>:$output,
+ I64Attr:$m,
+ I64Attr:$r
+ );
+
+ let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+ let assemblyFormat = [{
+ attr-dict
+ `m` `(` $m `)`
+ `r` `(` $r `)`
+ `ins` `(` $value `:` type($value) `)`
+ `outs` `(` $output `:` type($output) `)`
+ `->` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
#endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..835aeaf2ffed3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+ int64_t r);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..1283315f2eaef 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,113 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
return SmallVector<Value>{result};
}
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+ auto filterType = cast<ShapedType>(getFilter().getType());
+ ArrayRef<int64_t> filterShape = filterType.getShape();
+ int64_t filterH = filterShape[1];
+ int64_t filterW = filterShape[2];
+ int64_t r = getR();
+
+ if (filterH != r && filterH != 1)
+ return failure();
+ if (filterW != r && filterW != 1)
+ return failure();
+ if (filterH == 1 && filterW == 1)
+ return failure();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+ auto inputType = cast<ShapedType>(getInput().getType());
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ int64_t inputH = inputShape[1];
+ int64_t inputW = inputShape[2];
+ auto outputType = cast<ShapedType>(getOutput().getType());
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ int64_t outputH = outputShape[0];
+ int64_t outputW = outputShape[1];
+ int64_t outputTileH = outputShape[2];
+ int64_t outputTileW = outputShape[3];
+ int m = getM();
+ int r = getR();
+ bool leftTransform = inputH != 1;
+ bool rightTransform = inputW != 1;
+
+ if (!leftTransform && !rightTransform)
+ return failure();
+
+ if (leftTransform) {
+ int64_t tileH = (inputH - (r - 1)) / m;
+ if (inputH != tileH * m + (r - 1))
+ return failure();
+ if (tileH != outputTileH)
+ return failure();
+ if (outputH != m + r - 1)
+ return failure();
+ }
+
+ if (rightTransform) {
+ int64_t tileW = (inputW - (r - 1)) / m;
+ if (inputW != tileW * m + (r - 1))
+ return failure();
+ if (tileW != outputTileW)
+ return failure();
+ if (outputW != m + r - 1)
+ return failure();
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+ auto valueType = cast<ShapedType>(getValue().getType());
+ ArrayRef<int64_t> valueShape = valueType.getShape();
+ int64_t valueH = valueShape[0];
+ int64_t valueW = valueShape[1];
+ int64_t valueTileH = valueShape[2];
+ int64_t valueTileW = valueShape[3];
+ auto outputType = cast<ShapedType>(getOutput().getType());
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ int64_t outputH = outputShape[1];
+ int64_t outputW = outputShape[2];
+ int m = getM();
+ int r = getR();
+ bool leftTransform = valueH != 1;
+ bool rightTransform = valueW != 1;
+
+ if (!leftTransform && !rightTransform)
+ return failure();
+
+ if (leftTransform) {
+ if (valueH != m + r - 1)
+ return failure();
+ if (outputH != m * valueTileH)
+ return failure();
+ }
+
+ if (rightTransform) {
+ if (valueW != m + r - 1)
+ return failure();
+ if (outputW != m * valueTileW)
+ return failure();
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Transforms.cpp
TransposeConv2D.cpp
Vectorization.cpp
+ WinogradConv2D.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..6b46f9e07abf8
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,334 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. We can get
+/// the input dimension, alpha, from the formula, alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, we know its input size is 4.
+/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+/// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+/// This function generates linalg.batch_matmul to multiply input with filter.
+/// linalg.batch_matmul only supports 3-dimensional inputs. We collapse
+/// alphaH x alphaW into the batch dimension and tileH x tileW x N into the
+/// row dimension, i.e., convert [alphaH, alphaW, tileH, tileW, N, C] to
+/// [alphaH x alphaW, tileH x tileW x N, C]. In this way, we can convert
+/// 6-dimensional inputs to a 3-dimensional form for linalg.batch_matmul.
+///
+/// Batched matmul will do the matrix multiply with the reduction on channel.
+///
+/// We get
+///
+/// %collapsed_input = tensor.collapse_shape %input
+/// %collapsed_filter = tensor.collapse_shape %filter
+/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+/// %expanded_ret = tensor.expand_shape %ret
+///
+/// After this function, we get return value with data layout
+/// (alphaH, alphaW, tileH, tileW, N, F).
+static Value matrixMultiply(RewriterBase &rewriter, Location loc,
+ Value transformedFilter, Value transformedInput,
+ Type outputElementType) {
+ // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
+ auto filterType = cast<ShapedType>(transformedFilter.getType());
+ assert(filterType.hasStaticShape() && "only support static shapes.");
+ ArrayRef<int64_t> filterShape = filterType.getShape();
+ Type filterElementType = filterType.getElementType();
+ auto filterReassocType = RankedTensorType::get(
+ {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
+ filterElementType);
+ SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
+ Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
+ loc, filterReassocType, transformedFilter, filterReassoc);
+
+ // Convert (alphaH, alphaW, tileH, tileW, N, C) to
+ // (alphaH x alphaW, tileH x tileW x N, C) for input.
+ auto inputType = cast<ShapedType>(transformedInput.getType());
+ assert(inputType.hasStaticShape() && "only support static shapes.");
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ Type inputElementType = inputType.getElementType();
+ auto inputReassocType = RankedTensorType::get(
+ {inputShape[0] * inputShape[1],
+ inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
+ inputElementType);
+ SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+ Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
+ loc, inputReassocType, transformedInput, inputReassoc);
+
+ // Batched matrix multiply.
+ auto matmulType = RankedTensorType::get(
+ {inputShape[0] * inputShape[1],
+ inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
+ outputElementType);
+ Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ outputElementType);
+
+ auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+ loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+ ValueRange{init});
+
+ // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
+ // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
+ SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+ auto outputReassocType =
+ RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
+ inputShape[3], inputShape[4], filterShape[3]},
+ outputElementType);
+ auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+ loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
+ return expandOutput;
+}
+
+/// Pad the value with zeros at the high end of each dimension so that the
+/// resulting tensor has the given aligned shape.
+static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
+ Value value,
+ ArrayRef<int64_t> alignedShape) {
+ OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+ auto valueType = cast<ShapedType>(value.getType());
+ Type elementType = valueType.getElementType();
+ ArrayRef<int64_t> valueShape = valueType.getShape();
+ SmallVector<OpFoldResult, 6> lowIndices(alignedShape.size(), zeroIndex);
+ SmallVector<OpFoldResult, 6> highIndices;
+ for (unsigned i = 0; i < alignedShape.size(); ++i) {
+ highIndices.emplace_back(
+ rewriter.getIndexAttr(alignedShape[i] - valueShape[i]));
+ }
+ auto alignedType = RankedTensorType::get(alignedShape, elementType);
+ Value pad_value = rewriter.create<arith::ConstantOp>(
+ loc, elementType, rewriter.getZeroAttr(elementType));
+ return rewriter.create<tensor::PadOp>(loc, alignedType, value, lowIndices,
+ highIndices, pad_value);
+}
+
+/// Extract sub-tensor with extractedType from value.
+static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+ Value value,
+ RankedTensorType extractedType) {
+ OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+ OpFoldResult oneIndex = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+ SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+ ArrayRef<int64_t> extractedShape = extractedType.getShape();
+ SmallVector<OpFoldResult> sizes =
+ getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));
+
+ return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
+ offsets, sizes, strides);
+}
+
+/// Utility function to check all values in the attribute are 1.
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+ return llvm::all_of(
+ attr, [](const APInt &element) { return element.getSExtValue() == 1; });
+}
+
+/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
+/// linalg.winograd_*_transform ops.
+static FailureOr<Operation *>
+winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
+ int64_t m, int64_t r) {
+ Value input = convOp.getInputs()[0];
+ Value filter = convOp.getInputs()[1];
+ Value output = convOp.getOutputs()[0];
+ auto inputType = cast<ShapedType>(input.getType());
+ auto filterType = cast<ShapedType>(filter.getType());
+ auto outputType = cast<ShapedType>(output.getType());
+
+ if (!inputType.hasStaticShape())
+ return rewriter.notifyMatchFailure(convOp,
+ "expected a static shape for the input");
+
+ if (!filterType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected a static shape for the filter");
+
+ if (!hasAllOneValues(convOp.getDilations()))
+ return rewriter.notifyMatchFailure(convOp,
+ "expected all ones for dilations");
+
+ if (!hasAllOneValues(convOp.getStrides()))
+ return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
+
+ ArrayRef<int64_t> filterShape = filterType.getShape();
+ int64_t filterF = filterShape[0];
+ int64_t filterH = filterShape[1];
+ int64_t filterW = filterShape[2];
+ int64_t filterC = filterShape[3];
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ int64_t inputN = inputShape[0];
+ int64_t inputH = inputShape[1];
+ int64_t inputW = inputShape[2];
+ int64_t inputC = inputShape[3];
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ int64_t outputN = outputShape[0];
+ int64_t outputH = outputShape[1];
+ int64_t outputW = outputShape[2];
+ int64_t outputF = outputShape[3];
+
+ // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
+ bool isSupportedFilter = false;
+ if (filterH == filterW && filterH == r)
+ isSupportedFilter = true;
+ if (filterH == r && filterW == 1)
+ isSupportedFilter = true;
+ if (filterH == 1 && filterW == r)
+ isSupportedFilter = true;
+
+ if (!isSupportedFilter)
+ return rewriter.notifyMatchFailure(
+ convOp, "only support filter (r x r), (r x 1) or (1 x r)");
+
+ // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5).
+ static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+ F_2_3, F_4_3, F_2_5};
+
+ TransformMapKeyTy key = {m, r};
+ auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+ // If we cannot find the constant transformation matrix, it means we do
+ // not support this configuration yet.
+ if (it == validConfigs.end())
+ return failure();
+
+ // All the criteria are satisfied. We can do Winograd Conv2D.
+ Location loc = convOp.getLoc();
+
+ // For F(m x 1, r x 1), we only need to do left side transform.
+ bool leftTransform = filterH != 1;
+ // For F(1 x m, 1 x r), we only need to do right side transform.
+ bool rightTransform = filterW != 1;
+ int64_t heightM = leftTransform ? m : 1;
+ int64_t widthM = rightTransform ? m : 1;
+ int64_t heightR = leftTransform ? r : 1;
+ int64_t widthR = rightTransform ? r : 1;
+
+ // --- Create operation for filter transform ---
+ Type filterElementType = filterType.getElementType();
+ int64_t alphaH = heightM + heightR - 1;
+ int64_t alphaW = widthM + widthR - 1;
+ int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
+ int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
+ auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
+ filterElementType);
+ Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+ filterElementType);
+ auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+ loc, retType, filter, retValue, m, r);
+
+ // --- Create operation for input transform ---
+
+ // When input size - (r - 1) is not aligned with output tile size, we need to
+ // pad the input data to create the full tiles as tiling.
+ Type inputElementType = inputType.getElementType();
+ int64_t alignedInputH = tileH * heightM + (heightR - 1);
+ int64_t alignedInputW = tileW * widthM + (widthR - 1);
+ if (alignedInputH != inputH || alignedInputW != inputW) {
+ input = insertToAlignedTensor(
+ rewriter, loc, input, {inputN, alignedInputH, alignedInputW, inputC});
+ }
+
+ retType = RankedTensorType::get(
+ {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
+ retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+ inputElementType);
+ auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
+ loc, retType, input, retValue, m, r);
+
+ Type outputElementType = outputType.getElementType();
+ Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
+ transformedInput, outputElementType);
+
+ // --- Create operation for output transform ---
+
+ // When output size is not aligned with output tile size, we need to pad the
+ // output buffer to insert the full tiles after tiling.
+ int64_t alignedOutputH = tileH * heightM;
+ int64_t alignedOutputW = tileW * widthM;
+ bool isOutputUnaligned =
+ ((alignedOutputH != outputH) || (alignedOutputW != outputW));
+ if (isOutputUnaligned) {
+ auto alignedOutputType = RankedTensorType::get(
+ {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
+ output = insertToAlignedTensor(rewriter, loc, output,
+ alignedOutputType.getShape());
+ outputType = alignedOutputType;
+ }
+
+ Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
+ loc, outputType, matmulRet, output, m, r);
+
+ // When output size is not aligned with output tile size, extract the
+ // value from the padded buffer.
+ if (isOutputUnaligned) {
+ transformedOutput = extractFromAlignedTensor(
+ rewriter, loc, transformedOutput,
+ RankedTensorType::get({outputN, outputH, outputW, outputF},
+ outputElementType));
+ }
+
+ rewriter.replaceOp(convOp, transformedOutput);
+
+ return transformedOutput.getDefiningOp();
+}
+
+/// A rewrite pattern for Winograd Conv2D algorithm.
+class WinogradConv2DNhwcFhwc final
+ : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
+ : OpRewritePattern(context), m(m), r(r) {}
+
+ LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+ PatternRewriter &rewriter) const override {
+ if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
+ return failure();
+
+ return success();
+ }
+
+private:
+ int64_t m;
+ int64_t r;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+ int64_t r) {
+ MLIRContext *context = patterns.getContext();
+ patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
+}
+
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
new file mode 100644
index 0000000000000..ec11a6ef8fbee
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -0,0 +1,193 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s
+
+func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+ return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_4x4_3x3
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT: return %[[S8]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%out : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+ return %0 : tensor<2x2x2x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_2x2_5x5
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+// CHECK-NEXT: return %[[S8]] : tensor<2x2x2x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%out : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+ return %0 : tensor<2x1x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_1x4_1x3
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> {
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32>
+// CHECK-NEXT: return %[[S8]] : tensor<2x1x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%out : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+ return %0 : tensor<2x4x1x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_4x1_3x1
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> {
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32>
+// CHECK-NEXT: return %[[S8]] : tensor<2x4x1x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ return %0 : tensor<2x8x8x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_aligned
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+ return %0 : tensor<2x9x9x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+// CHECK-NEXT: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT: ^bb0
+// CHECK-NEXT: tensor.yield %[[CST]] : f32
+// CHECK-NEXT: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>
+// CHECK-NEXT: %[[PADDED_1:.*]] = tensor.pad %[[ARG3]] low[0, 0, 0, 0] high[0, 3, 3, 0] {
+// CHECK-NEXT: ^bb0
+// CHECK-NEXT: tensor.yield %[[CST]] : f32
+// CHECK-NEXT: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf16>, tensor<2x3x3x5xf16>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+ return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @conv2d_type_promotion
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT: return %[[S6]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+ return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_1
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+ return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_2
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_3
+// CHECK: linalg.conv_2d_nhwc_fhwc
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..12cb46a5968f1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,10 @@ struct TestLinalgTransforms
*this, "test-erase-unnecessary-inputs",
llvm::cl::desc("Test patterns to erase unnecessary inputs"),
llvm::cl::init(false)};
+ Option<bool> testWinogradConv2D{
+ *this, "test-winograd-conv2d",
+ llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -207,6 +211,13 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+static void applyWinogradConv2D(func::FuncOp funcOp) {
+ RewritePatternSet patterns(funcOp.getContext());
+ populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
+ populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
if (testPatterns)
@@ -231,6 +242,8 @@ void TestLinalgTransforms::runOnOperation() {
return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
if (testEraseUnnecessaryInputs)
return applyEraseUnnecessaryInputs(getOperation());
+ if (testWinogradConv2D)
+ return applyWinogradConv2D(getOperation());
}
namespace mlir {
>From bb8087930cfd79a3d4ebf6a8e959f4c30bb70fcf Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 17 Jun 2024 11:49:08 +0100
Subject: [PATCH 5/6] [mlir][linalg] Add transform operator for Winograd Conv2D
algorithm
Add a transform operator structured.winograd_conv2d to convert
linalg.conv_2d_nhwc_fhwc to Linalg winograd operators.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 51 +++++++++++
.../Dialect/Linalg/Transforms/Transforms.h | 7 ++
.../TransformOps/LinalgTransformOps.cpp | 25 ++++++
.../Linalg/Transforms/WinogradConv2D.cpp | 6 ++
.../Linalg/transform-winograd-conv2d.mlir | 88 +++++++++++++++++++
5 files changed, 177 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..68d0f713caad4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
}];
}
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+ "structured.winograd_conv2d",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+ matrix multiply. Before the matrix multiply, it will convert filter and
+ input into a format suitable for batched matrix multiply. After the matrix
+ multiply, it will convert output to the final result tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ #### Return modes:
+
+ This operation fails if `target` is unsupported. Otherwise, the operation
+ succeeds and returns a handle of the sequence that replaces the original
+ convolution.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ I64Attr:$m,
+ I64Attr:$r);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type($target, results)";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>
+ ];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::linalg::LinalgOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 835aeaf2ffed3..da107b66257a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp op,
bool transposeLHS = true);
+/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
+/// F(m x m, r x r). m is the dimension size of output and r is the dimension
+/// size of filter.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+ linalg::Conv2DNhwcFhwcOp op, int64_t m,
+ int64_t r);
+
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..d051b29e1f06f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+ transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ rewriter.setInsertionPoint(target);
+ auto maybeTransformed =
+ TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+ .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+ return winogradConv2D(rewriter, op, getM(), getR());
+ })
+ .Default([&](Operation *op) {
+ return rewriter.notifyMatchFailure(op, "not supported");
+ });
+
+ if (failed(maybeTransformed))
+ return emitDefaultSilenceableFailure(target);
+
+ results.push_back(*maybeTransformed);
+ return DiagnosedSilenceableFailure::success();
+}
+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 6b46f9e07abf8..843db0c069813 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -324,6 +324,12 @@ class WinogradConv2DNhwcFhwc final
} // end anonymous namespace
//===----------------------------------------------------------------------===//
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+ linalg::Conv2DNhwcFhwcOp op, int64_t m,
+ int64_t r) {
+ return winogradConv2DHelper(rewriter, op, m, r);
+}
+
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
int64_t r) {
MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
new file mode 100644
index 0000000000000..1e74fea5a1c31
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+ %0 = tensor.empty() : tensor<2x8x8x2xf32>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<2x8x8x2xf32>
+ %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ return %2 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+ %0 = tensor.empty() : tensor<2x9x9x2xf32>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<2x9x9x2xf32>
+ %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+ return %2 : tensor<2x9x9x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
+// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x9x9x2xf32>
+// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
+// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
+// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
+// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
+// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
+// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
+// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }
>From cc23f43cfab82f1c0b9ddbf6cacd29a20f99d825 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 26 Jun 2024 12:26:15 +0100
Subject: [PATCH 6/6] Address ftynse's comments
---
.../Linalg/TransformOps/LinalgTransformOps.td | 8 +-
.../TransformOps/LinalgTransformOps.cpp | 26 ++--
.../Linalg/transform-winograd-conv2d.mlir | 112 ++++++++----------
3 files changed, 69 insertions(+), 77 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 68d0f713caad4..5ef56bc97fef1 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2597,7 +2597,7 @@ def WinogradConv2DOp : Op<Transform_Dialect,
TransformOpInterface, TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+ Winograd Conv2D algorithm will convert linalg Conv2D operation into batched
matrix multiply. Before the matrix multiply, it will convert filter and
input into a format suitable for batched matrix multiply. After the matrix
multiply, it will convert output to the final result tensor.
@@ -2612,9 +2612,9 @@ def WinogradConv2DOp : Op<Transform_Dialect,
#### Return modes:
- This operation fails if `target` is unsupported. Otherwise, the operation
- succeeds and returns a handle of the sequence that replaces the original
- convolution.
+ This operation produces a silenceable failure if `target` is unsupported.
+ Otherwise, the operation succeeds and returns a handle of the sequence that
+ replaces the original convolution.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d051b29e1f06f..e0f2d00400d63 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3489,17 +3489,21 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- auto maybeTransformed =
- TypeSwitch<Operation *, FailureOr<Operation *>>(target)
- .Case([&](linalg::Conv2DNhwcFhwcOp op) {
- return winogradConv2D(rewriter, op, getM(), getR());
- })
- .Default([&](Operation *op) {
- return rewriter.notifyMatchFailure(op, "not supported");
- });
-
- if (failed(maybeTransformed))
- return emitDefaultSilenceableFailure(target);
+ FailureOr<Operation *> maybeTransformed = failure();
+ bool supported = TypeSwitch<Operation *, bool>(target)
+ .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+ maybeTransformed =
+ winogradConv2D(rewriter, op, getM(), getR());
+ return true;
+ })
+ .Default([&](Operation *op) {
+ op->emitError("not supported");
+ return false;
+ });
+
+ if (supported && failed(maybeTransformed)) {
+ return emitSilenceableError() << "apply Winograd Conv2D failed";
+ }
results.push_back(*maybeTransformed);
return DiagnosedSilenceableFailure::success();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index 1e74fea5a1c31..0a2dcc035ebd3 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -1,13 +1,8 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics | FileCheck %s
-func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
- %0 = tensor.empty() : tensor<2x8x8x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x8x8x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
- return %2 : tensor<2x8x8x2xf32>
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ return %0 : tensor<2x8x8x2xf32>
}
module attributes {transform.with_named_sequence} {
@@ -18,38 +13,17 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @conv2d
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT: linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x8x8x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
-// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32>
-// CHECK-NEXT: }
+// CHECK: linalg.winograd_filter_transform m(4) r(3)
+// CHECK: linalg.winograd_input_transform m(4) r(3)
+// CHECK: linalg.batch_matmul
+// CHECK: linalg.winograd_output_transform m(4) r(3)
// -----
-func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
- %0 = tensor.empty() : tensor<2x9x9x2xf32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<2x9x9x2xf32>
- %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
- return %2 : tensor<2x9x9x2xf32>
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+ return %0 : tensor<2x9x9x2xf32>
}
module attributes {transform.with_named_sequence} {
@@ -60,29 +34,43 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @conv2d_unaligned
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
-// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
-// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
-// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK-NEXT: linalg.yield %[[IN]] : f32
-// CHECK-NEXT: } -> tensor<2x9x9x2xf32>
-// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
-// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
-// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
-// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
-// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
-// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
-// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
-// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
-// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
-// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
-// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
-// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
-// CHECK-NEXT: }
+// CHECK: linalg.winograd_filter_transform m(4) r(3)
+// CHECK: tensor.pad
+// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK: linalg.winograd_input_transform m(4) r(3)
+// CHECK: tensor.pad
+// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK: linalg.winograd_output_transform m(4) r(3)
+
+// -----
+
+func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+ // expected-error @+1 {{not supported}}
+ %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+ return %0 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
+ return %0 : tensor<2x?x?x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @+1 {{apply Winograd Conv2D failed}}
+ %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
More information about the llvm-branch-commits mailing list