[Mlir-commits] [mlir] [mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm (PR #96181)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 26 08:52:42 PDT 2024
================
@@ -0,0 +1,334 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+/// We use F(m, r) to denote a minimal filtering algorithm, where m is the
+/// output tile size and r is the filter size. The input tile size, alpha,
+/// follows from alpha = m + r - 1.
+///
+/// For example, F(2, 3) has alpha = 4: the Conv2D operates on 4x4 input
+/// tiles with a 3x3 filter and produces 2x2 output tiles.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+/// Generate a linalg.batch_matmul that multiplies the transformed input with
+/// the transformed filter, reducing over the input channel C.
+///
+/// linalg.batch_matmul only accepts 3-D operands, so both operands are first
+/// collapsed:
+///   filter: (alphaH, alphaW, C, F) -> (alphaH x alphaW, C, F)
+///   input:  (alphaH, alphaW, tileH, tileW, N, C)
+///           -> (alphaH x alphaW, tileH x tileW x N, C)
+/// The leading alphaH x alphaW dimension is the batch dimension, so each batch
+/// element computes (tileH x tileW x N, C) x (C, F).
+///
+/// linalg.batch_matmul accumulates into its init operand, so the init tensor
+/// is zero-filled before the multiply.
+///
+/// We get
+///
+///   %collapsed_input = tensor.collapse_shape %input
+///   %collapsed_filter = tensor.collapse_shape %filter
+///   %init = linalg.fill ins(%zero) outs(%empty)
+///   %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter, %init
+///   %expanded_ret = tensor.expand_shape %ret
+///
+/// The returned value has layout (alphaH, alphaW, tileH, tileW, N, F).
+/// Both operands must have static shapes (checked by assertions only).
+static Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                            Value transformedFilter, Value transformedInput,
+                            Type outputElementType) {
+  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
+  auto filterType = cast<ShapedType>(transformedFilter.getType());
+  assert(filterType.hasStaticShape() && "only support static shapes.");
+  assert(filterType.getRank() == 4 && "expected a rank-4 filter.");
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  Type filterElementType = filterType.getElementType();
+  auto filterReassocType = RankedTensorType::get(
+      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
+      filterElementType);
+  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
+  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, filterReassocType, transformedFilter, filterReassoc);
+
+  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
+  // (alphaH x alphaW, tileH x tileW x N, C) for input.
+  auto inputType = cast<ShapedType>(transformedInput.getType());
+  assert(inputType.hasStaticShape() && "only support static shapes.");
+  assert(inputType.getRank() == 6 && "expected a rank-6 input.");
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  // The reduction dimension C must agree between the two operands.
+  assert(inputShape[5] == filterShape[2] &&
+         "input and filter channel sizes must match.");
+  Type inputElementType = inputType.getElementType();
+  auto inputReassocType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
+      inputElementType);
+  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, inputReassocType, transformedInput, inputReassoc);
+
+  // Batched matrix multiply.
+  auto matmulType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
+      outputElementType);
+  Value empty = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 outputElementType);
+  // batch_matmul computes init += lhs * rhs, and tensor.empty has unspecified
+  // contents, so the accumulator must be zero-initialized explicitly.
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outputElementType, rewriter.getZeroAttr(outputElementType));
+  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+
+  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+      ValueRange{init});
+
+  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
+  // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
+  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  auto outputReassocType =
+      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
+                             inputShape[3], inputShape[4], filterShape[3]},
+                            outputElementType);
+  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
+  return expandOutput;
+}
+
+/// Create an empty tensor with alignedType and insert the value into the
+/// created empty tensor with aligned size.
+static Value insertToAlignedTensor(RewriterBase &rewriter, Location loc,
+ Value value,
+ ArrayRef<int64_t> alignedShape) {
+ OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+ auto valueType = cast<ShapedType>(value.getType());
+ Type elementType = valueType.getElementType();
+ ArrayRef<int64_t> valueShape = valueType.getShape();
+ SmallVector<OpFoldResult, 6> lowIndices(alignedShape.size(), zeroIndex);
+ SmallVector<OpFoldResult, 6> highIndices;
+ for (unsigned i = 0; i < alignedShape.size(); ++i) {
+ highIndices.emplace_back(
+ rewriter.getIndexAttr(alignedShape[i] - valueShape[i]));
+ }
+ auto alignedType = RankedTensorType::get(alignedShape, elementType);
+ Value pad_value = rewriter.create<arith::ConstantOp>(
+ loc, elementType, rewriter.getZeroAttr(elementType));
+ return rewriter.create<tensor::PadOp>(loc, alignedType, value, lowIndices,
----------------
Max191 wrote:
nit: Can you use this util?
https://github.com/llvm/llvm-project/blob/17eaa23f7ecdfe79ad74552aaa260e6ce32432c2/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h#L75
https://github.com/llvm/llvm-project/pull/96181
More information about the Mlir-commits
mailing list