[llvm-branch-commits] [mlir] [mlir][linalg] Add transform operator for Winograd Conv2D algorithm (PR #96182)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jun 20 05:41:12 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
Author: Hsiangkai Wang (Hsiangkai)
<details>
<summary>Changes</summary>
Add a transform operator structured.winograd_conv2d to convert
linalg.conv_2d_nhwc_fhwc to Linalg winograd operators.
---
Patch is 57.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96182.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+114)
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+51)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+11)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+78)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+25)
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+327)
- (added) mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir (+88)
- (added) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+248)
- (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+13)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..de1097b6ac27b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}
+// Filter-transform stage of the Winograd Conv2D algorithm, i.e. the
+// (G x g x G^T) term of the formula in the description below. `m` and `r` are
+// the parameters of F(m x m, r x r). In the custom assembly `$filter` is given
+// in the `ins` clause and `$output` in the `outs` clause. Additional operand
+// checks are implemented in the C++ verifier (`hasVerifier = 1`).
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+ let summary = "Winograd filter transform operator";
+ let description = [{
+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+ matrix multiply. Before the matrix multiply, it will convert filter and
+ input into a format suitable for batched matrix multiply. After the matrix
+ multiply, it will convert output to the final result tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ This operator is defined to represent the high level concept of filter
+ transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+ }];
+
+ let arguments = (ins AnyRankedTensor:$filter,
+ AnyRankedTensor:$output,
+ I64Attr:$m,
+ I64Attr:$r
+ );
+
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ attr-dict
+ `m` `(` $m `)`
+ `r` `(` $r `)`
+ `ins` `(` $filter `:` type($filter) `)`
+ `outs` `(` $output `:` type($output) `)`
+ `->` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
+// Input-transform stage of the Winograd Conv2D algorithm, i.e. the
+// (B^T x d x B) term of the formula in the description below. `m` and `r` are
+// the parameters of F(m x m, r x r). In the custom assembly `$input` is given
+// in the `ins` clause and `$output` in the `outs` clause. Additional operand
+// checks are implemented in the C++ verifier (`hasVerifier = 1`).
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+ let summary = "Winograd input transform operator";
+ let description = [{
+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+ matrix multiply. Before the matrix multiply, it will convert filter and
+ input into a format suitable for batched matrix multiply. After the matrix
+ multiply, it will convert output to the final result tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ This operator is defined to represent the high level concept of input
+ transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+ }];
+
+ let arguments = (ins AnyRankedTensor:$input,
+ AnyRankedTensor:$output,
+ I64Attr:$m,
+ I64Attr:$r
+ );
+
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ attr-dict
+ `m` `(` $m `)`
+ `r` `(` $r `)`
+ `ins` `(` $input `:` type($input) `)`
+ `outs` `(` $output `:` type($output) `)`
+ `->` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
+// Output-transform stage of the Winograd Conv2D algorithm, i.e. the outer
+// A^T x [...] x A multiplication of the formula in the description below.
+// `m` and `r` are the parameters of F(m x m, r x r). In the custom assembly
+// `$value` is given in the `ins` clause and `$output` in the `outs` clause.
+// Additional operand checks are implemented in the C++ verifier
+// (`hasVerifier = 1`).
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+ let summary = "Winograd output transform operator";
+ let description = [{
+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+ matrix multiply. Before the matrix multiply, it will convert filter and
+ input into a format suitable for batched matrix multiply. After the matrix
+ multiply, it will convert output to the final result tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ This operator is defined to represent the high level concept of output
+ transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+ }];
+
+ let arguments = (ins AnyRankedTensor:$value,
+ AnyRankedTensor:$output,
+ I64Attr:$m,
+ I64Attr:$r
+ );
+
+ let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ attr-dict
+ `m` `(` $m `)`
+ `r` `(` $r `)`
+ `ins` `(` $value `:` type($value) `)`
+ `outs` `(` $output `:` type($output) `)`
+ `->` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
#endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..68d0f713caad4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
}];
}
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+// Transform-dialect entry point for the Winograd Conv2D rewrite. It is applied
+// to each payload op of `$target` individually (TransformEachOpTrait) through
+// `applyToOne`; `m` and `r` select the F(m x m, r x r) variant.
+// NOTE(review): the builder declared below takes only `$target`, while `m` and
+// `r` are required attributes, and no definition for it is visible in this
+// patch -- confirm it is defined before relying on it.
+def WinogradConv2DOp : Op<Transform_Dialect,
+ "structured.winograd_conv2d",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+ matrix multiply. Before the matrix multiply, it will convert filter and
+ input into a format suitable for batched matrix multiply. After the matrix
+ multiply, it will convert output to the final result tensor.
+
+ The algorithm F(m x m, r x r) is
+
+ Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+ The size of output Y is m x m. The size of filter g is r x r. The size of
+ input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+ transformation matrices.
+
+ #### Return modes:
+
+ This operation produces a silenceable failure if `target` is unsupported
+ or cannot be rewritten. Otherwise, the operation succeeds and returns a
+ handle of the sequence that replaces the original convolution.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ I64Attr:$m,
+ I64Attr:$r);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type($target, results)";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>
+ ];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::linalg::LinalgOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..da107b66257a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp op,
bool transposeLHS = true);
+/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
+/// F(m x m, r x r). m is the dimension size of output and r is the dimension
+/// size of filter.
+///
+/// \param rewriter rewriter handed to the conversion.
+/// \param op       the convolution to convert.
+/// \param m        output dimension size of F(m x m, r x r).
+/// \param r        filter dimension size of F(m x m, r x r).
+/// \returns failure, or an operation pointer on success -- presumably the op
+///          that replaces `op`; TODO confirm against the definition.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+ linalg::Conv2DNhwcFhwcOp op, int64_t m,
+ int64_t r);
+
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
@@ -1692,6 +1699,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+/// `m` is the output dimension size and `r` the filter dimension size, with
+/// the same meaning as in winogradConv2D(); the patterns are added to
+/// `patterns`.
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+ int64_t r);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..7bf2a5bca037f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,84 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
return SmallVector<Value>{result};
}
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies linalg.winograd_filter_transform: the filter and output operands
+/// must share an element type, the filter must have rank 4 and the output must
+/// have rank 6. Diagnostics name the operand as it is spelled in the op
+/// definition (`filter`), so they cannot be confused with the diagnostics of
+/// linalg.winograd_input_transform.
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  Type filterElemType = filterType.getElementType();
+  Type outputElemType = outputType.getElementType();
+  if (filterElemType != outputElemType) {
+    return emitOpError() << "expected element type of filter "
+                         << filterElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  // getRank() yields int64_t; compare it directly rather than narrowing it to
+  // `unsigned` first.
+  if (filterType.getRank() != 4)
+    return emitOpError() << "expected rank of filter is 4";
+
+  if (outputType.getRank() != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies linalg.winograd_input_transform: input and output must agree on
+/// the element type, the input must be rank 4 and the output rank 6. The
+/// checks run in that order and the first violation is reported.
+LogicalResult WinogradInputTransformOp::verify() {
+  auto srcType = cast<ShapedType>(getInput().getType());
+  auto dstType = cast<ShapedType>(getOutput().getType());
+
+  Type srcElemType = srcType.getElementType();
+  Type dstElemType = dstType.getElementType();
+  if (srcElemType != dstElemType)
+    return emitOpError() << "expected element type of input " << srcElemType
+                         << " to match element type of output "
+                         << dstElemType;
+
+  if (srcType.getRank() != 4)
+    return emitOpError() << "expected rank of input is 4";
+
+  if (dstType.getRank() != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies linalg.winograd_output_transform: the value and output operands
+/// must share an element type, the value must have rank 6 and the output must
+/// have rank 4. Both diagnostics about the first operand call it `value`, the
+/// name used in the op definition.
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  Type valueElemType = valueType.getElementType();
+  Type outputElemType = outputType.getElementType();
+  if (valueElemType != outputElemType) {
+    return emitOpError() << "expected element type of value " << valueElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  // getRank() yields int64_t; compare it directly rather than narrowing it to
+  // `unsigned` first.
+  if (valueType.getRank() != 6)
+    return emitOpError() << "expected rank of value is 6";
+
+  if (outputType.getRank() != 4)
+    return emitOpError() << "expected rank of output is 4";
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..d051b29e1f06f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+/// Applies the Winograd Conv2D rewrite to a single payload op. Only
+/// linalg.conv_2d_nhwc_fhwc is handled; for any other op a match failure is
+/// reported to the rewriter. A failed rewrite yields the default silenceable
+/// failure for `target`; otherwise the operation returned by winogradConv2D()
+/// is appended to `results`.
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+
+  Operation *payload = target.getOperation();
+  FailureOr<Operation *> rewritten = failure();
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(payload))
+    rewritten = winogradConv2D(rewriter, convOp, getM(), getR());
+  else
+    rewritten = FailureOr<Operation *>(
+        rewriter.notifyMatchFailure(payload, "not supported"));
+
+  if (failed(rewritten))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*rewritten);
+  return DiagnosedSilenceableFailure::success();
+}
+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Transforms.cpp
TransposeConv2D.cpp
Vectorization.cpp
+ WinogradConv2D.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..d1f4be8bbf29a
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,327 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+// Key type identifying one minimal filtering algorithm F(m, r): `first` is m,
+// the output dimension, and `second` is r, the filter dimension.
+using TransformMapKeyTy = std::pair<int, int>;
+
+// The input dimension (alpha) follows from the key as alpha = m + r - 1.
+//
+// Example: for m = 2 and r = 3 the input size is 4, i.e. the Conv2D works on
+// 4x4 input data with a 3x3 filter and produces a 2x2 output.
+constexpr TransformMapKeyTy F_2_3(2, 3);
+constexpr TransformMapKeyTy F_4_3(4, 3);
+constexpr TransformMapKeyTy F_2_5(2, 5);
+
+/// Emits a tensor.collapse_shape that folds dimensions 0-3 of `data` into a
+/// single leading dimension and keeps dimensions 4 and 5 as they are, using
+/// the reassociation {{0, 1, 2, 3}, {4}, {5}}.
+/// NOTE(review): the collapsed size is computed by multiplying the dimension
+/// sizes as plain integers, which presumably requires `data` to be a rank-6
+/// tensor with static shape -- confirm callers guarantee that.
+Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
+  auto tensorType = cast<ShapedType>(data.getType());
+  ArrayRef<int64_t> dims = tensorType.getShape();
+
+  int64_t flattenedSize = dims[0] * dims[1] * dims[2] * dims[3];
+  auto collapsedType = RankedTensorType::get(
+      {flattenedSize, dims[4], dims[5]}, tensorType.getElementType());
+
+  SmallVector<ReassociationIndices> groups = {{0, 1, 2, 3}, {4}, {5}};
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, data,
+                                                  groups);
+}
+
+// Multiplies the transformed input with the transformed filter through
+// linalg.batch_matmul.
+//
+// linalg.batch_matmul works on 3-dimensional operands, so the four leading
+// dimensions (tileH x tileW x H x W) of each 6-dimensional operand are treated
+// as one flat batch dimension: [tileH, tileW, H, W, N, C] becomes
+// [tileH x tileW x H x W, N, C]. The batched matmul then reduces over the
+// channel dimension, and its result is expanded back to six dimensions.
+//
+// The generated IR is
+//
+// %collapsed_input = tensor.collapse_shape %input
+// %collapsed_filter = tensor.collapse_shape %filter
+// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+// %expanded_ret = tensor.expand_shape %ret
+//
+// The returned value has the data layout (tileH, tileW, H, W, N, F).
+Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                     Value transformedFilter, Value transformedInput) {
+  // Collapse the filter first and the input second, so the ops are emitted in
+  // that order.
+  Value flatFilter = collapse2DData(rewriter, loc, transformedFilter);
+  Value flatInput = collapse2DData(rewriter, loc, transformedInput);
+
+  ArrayRef<int64_t> filterDims =
+      cast<ShapedType>(transformedFilter.getType()).getShape();
+  auto inputTy = cast<ShapedType>(transformedInput.getType());
+  Type elemTy = inputTy.getElementType();
+  ArrayRef<int64_t> inputDims = inputTy.getShape();
+
+  // Batched matrix multiply on the flattened operands.
+  int64_t batchSize =
+      inputDims[0] * inputDims[1] * inputDims[2] * inputDims[3];
+  auto flatResultTy =
+      RankedTensorType::get({batchSize, inputDims[4], filterDims[5]}, elemTy);
+  Value init =
+      rewriter.create<tensor::EmptyOp>(loc, flatResultTy.getShape(), elemTy);
+  auto batchMatmul = rewriter.create<linalg::BatchMatmulOp>(
+      loc, flatResultTy, ValueRange({flatInput, flatFilter}),
+      ValueRange{init});
+
+  // Restore the four leading dimensions on the matmul result.
+  auto expandedTy =
+      RankedTensorType::get({inputDims[0], inputDims[1], inputDims[2],
+                             inputDims[3], inputDims[4], filterDims[5]},
+                            elemTy);
+  SmallVector<ReassociationIndices> groups = {{0, 1, 2, 3}, {4}, {5}};
+  return rewriter.create<tensor::ExpandShapeOp>(
+      loc, expandedTy, batchMatmul.getResult(0), groups);
+}
+
+/// Creates a tensor.empty of type `alignedType` and inserts `value` into it
+/// with tensor.insert_slice at offset 0 and stride 1 in each of four
+/// dimensions. The slice sizes are the first four dimension sizes of `value`.
+/// Returns the result of the insert_slice.
+Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value,
+                            RankedTensorType alignedType) {
+  Value destination = rewriter.create<tensor::EmptyOp>(
+      loc, alignedType.getShape(), alignedType.getElementType());
+
+  ArrayRef<int64_t> sourceDims = cast<ShapedType>(value.getType()).getShape();
+  SmallVector<OpFoldResult, 4> sizes;
+  for (unsigned dim = 0; dim < 4; ++dim)
+    sizes.emplace_back(rewriter.getIndexAttr(sourceDims[dim]));
+
+  auto zero = rewriter.getIndexAttr(0);
+  auto one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets(4, zero);
+  SmallVector<OpFoldResult, 4> strides(4, one);
+
+  return rewriter.create<tensor::InsertSliceOp>(loc, value, destination,
+                                                offsets, sizes, strides);
+}
+
+Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+ Value value, RankedTensorType extractedType) {
+ auto zeroIndex = rewriter.getIndexAttr(0);
+ auto oneIndex = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+ SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+ auto extractedShape = extractedType.getShape();
+ SmallVector<OpFoldResult, 4> sizes;
+ sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0]));
+ sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1]));
+ sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2]));
+ sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3]));
+
+ return rewriter.create<tensor::ExtractSliceOp>(lo...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/96182
More information about the llvm-branch-commits
mailing list