[Mlir-commits] [mlir] [mlir][linalg] Implement Winograd Conv2D. (PR #94470)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Fri Jun 14 11:37:30 PDT 2024
================
@@ -0,0 +1,1022 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// Implement Winograd Conv2D algorithm. The implementation is based on the
+/// paper: Fast Algorithms for Convolutional Neural Networks
+/// (https://arxiv.org/abs/1509.09308)
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+// clang-format off
+// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+// result. For F(m x m, r x r), where m is the output dimension and r is the
+// filter dimension, the minimal 2D filtering algorithm computes
+//
+// Y = A^T x [ (G x g x G^T) (.) (B^T x d x B) ] x A
+//
+// where x is matrix multiplication and (.) is element-wise multiplication.
+//
+// g is filter and d is input data. We need to prepare 6 constant
+// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula.
+//
+// The following tables define these constant transformation matrices for
+// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5)
+constexpr float G_2x2_3x3[] = { // F(2 x 2, 3 x 3) filter transform, written as 4 x 3
+ -1, 0, 0,
+ 1./2, -1./2, 1./2,
+ 1./2, 1./2, 1./2,
+ 0, 0, 1
+};
+
+constexpr float GT_2x2_3x3[] = { // 3 x 4, transpose of G_2x2_3x3
+ -1, 1./2, 1./2, 0,
+ 0, -1./2, 1./2, 0,
+ 0, 1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = { // F(2 x 2, 3 x 3) input transform, written as 4 x 4
+ -1, 0, 1, 0,
+ 0, -1, 1, 0,
+ 0, 1, 1, 0,
+ 0, -1, 0, 1
+};
+
+constexpr float B_2x2_3x3[] = { // 4 x 4, transpose of BT_2x2_3x3
+ -1, 0, 0, 0,
+ 0, -1, 1, -1,
+ 1, 1, 1, 0,
+ 0, 0, 0, 1
+};
+
+constexpr float AT_2x2_3x3[] = { // F(2 x 2, 3 x 3) output transform, written as 2 x 4
+ 1, 1, 1, 0,
+ 0, -1, 1, 1
+};
+
+constexpr float A_2x2_3x3[] = { // 4 x 2, transpose of AT_2x2_3x3
+ 1, 0,
+ 1, -1,
+ 1, 1,
+ 0, 1
+};
+
+constexpr float G_4x4_3x3[] = { // F(4 x 4, 3 x 3) filter transform, written as 6 x 3
+ 1, 0, 0,
+ -1./3, 1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6, 1./3,
+ 1./12, 1./6, 1./3,
+ 0, 0, 1
+};
+
+constexpr float GT_4x4_3x3[] = { // 3 x 6, transpose of G_4x4_3x3
+ 1, -1./3, -1./3, 1./12, 1./12, 0,
+ 0, 1./3, -1./3, -1./6, 1./6, 0,
+ 0, -1./3, -1./3, 1./3, 1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = { // F(4 x 4, 3 x 3) input transform, written as 6 x 6
+ 1./4, 0, -5./16, 0, 1./16, 0,
+ 0, 1./4, -1./4, -1./16, 1./16, 0,
+ 0, -1./4, -1./4, 1./16, 1./16, 0,
+ 0, 1./4, -1./8, -1./4, 1./8, 0,
+ 0, -1./4, -1./8, 1./4, 1./8, 0,
+ 0, 1./4, 0, -5./16, 0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = { // 6 x 6, transpose of BT_4x4_3x3
+ 1./4, 0, 0, 0, 0, 0,
+ 0, 1./4, -1./4, 1./4, -1./4, 1./4,
+ -5./16, -1./4, -1./4, -1./8, -1./8, 0,
+ 0, -1./16, 1./16, -1./4, 1./4, -5./16,
+ 1./16, 1./16, 1./16, 1./8, 1./8, 0,
+ 0, 0, 0, 0, 0, 1./16
+};
+
+constexpr float AT_4x4_3x3[] = { // F(4 x 4, 3 x 3) output transform, written as 4 x 6
+ 1./8, 1./4, 1./4, 1./8, 1./8, 0,
+ 0, -1./4, 1./4, -1./4, 1./4, 0,
+ 0, 1./4, 1./4, 1./2, 1./2, 0,
+ 0, -1./4, 1./4, -1, 1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = { // 6 x 4, transpose of AT_4x4_3x3
+ 1./8, 0, 0, 0,
+ 1./4, -1./4, 1./4, -1./4,
+ 1./4, 1./4, 1./4, 1./4,
+ 1./8, -1./4, 1./2, -1,
+ 1./8, 1./4, 1./2, 1,
+ 0, 0, 0, 1./2
+};
+
+constexpr float G_2x2_5x5[] = { // F(2 x 2, 5 x 5) filter transform, written as 6 x 5
+ 1, 0, 0, 0, 0,
+ 1./6, -1./6, 1./6, -1./6, 1./6,
+ -1./6, -1./6, -1./6, -1./6, -1./6,
+-4./15, 2./15, -1./15, 1./30, -1./60,
+ 1./60, 1./30, 1./15, 2./15, 4./15,
+ 0, 0, 0, 0, 1
+};
+
+constexpr float GT_2x2_5x5[] = { // 5 x 6, transpose of G_2x2_5x5
+ 1, 1./6, -1./6, -4./15, 1./60, 0,
+ 0, -1./6, -1./6, 2./15, 1./30, 0,
+ 0, 1./6, -1./6, -1./15, 1./15, 0,
+ 0, -1./6, -1./6, 1./30, 2./15, 0,
+ 0, 1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = { // F(2 x 2, 5 x 5) input transform, written as 6 x 6
+ 1./8, 3./16, -1./4, -3./16, 1./8, 0,
+ 0, 1./8, 1./16, -5./16, 1./8, 0,
+ 0, -1./8, -5./16, -1./16, 1./8, 0,
+ 0, 1./4, -1./8, -1./4, 1./8, 0,
+ 0, -1./8, -1./4, 1./8, 1./4, 0,
+ 0, 1./8, 3./16, -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = { // 6 x 6, transpose of BT_2x2_5x5
+ 1./8, 0, 0, 0, 0, 0,
+ 3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
+ -1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
+ -3./16, -5./16, -1./16, -1./4, 1./8, -1./4,
+ 1./8, 1./8, 1./8, 1./8, 1./4, -3./16,
+ 0, 0, 0, 0, 0, 1./8
+};
+
+constexpr float AT_2x2_5x5[] = { // F(2 x 2, 5 x 5) output transform, written as 2 x 6
+ 1./2, 1, 1, 2, 1, 0,
+ 0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = { // 6 x 2, transpose of AT_2x2_5x5
+ 1./2, 0,
+ 1, -1,
+ 1, 1,
+ 2, -1,
+ 1, 2,
+ 0, 1./2
+};
+// clang-format on
+
+using TransformMapKeyTy = std::pair<int, int>; // holds (m, r) of F(m, r)
+
+// We use F(m, r) to define the size of minimal filtering algorithms.
+// m is the output dimension and r is the filter dimension. We can get
+// the input dimension, alpha, from the formula, alpha = m + r - 1.
+//
+// For example, when m = 2 and r = 3, we know its input size is 4.
+// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3}; // m = 2, r = 3, so alpha = 4
+constexpr TransformMapKeyTy F_4_3{4, 3}; // m = 4, r = 3, so alpha = 6
+constexpr TransformMapKeyTy F_2_5{2, 5}; // m = 2, r = 5, so alpha = 6
+/* Describes one constant transformation matrix: a pointer to its flat
+ * coefficient table together with its row and column counts.
+ * `scalarFactor` defaults to 1; in this excerpt it is only stored, never
+ * read, so its meaning has to be confirmed against the code that uses it.
+ */
+struct TransformMatrix {
+ TransformMatrix(const float *table, int64_t rows, int64_t cols,
+ int64_t scalarFactor = 1)
+ : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+
+ const float *table; // presumably one of the constexpr tables above -- confirm
+ int64_t rows; // number of rows in `table`
+ int64_t cols; // number of columns in `table`
+ int64_t scalarFactor; // NOTE(review): purpose not visible here
+};
+/* Creates an arith.constant that holds `transform` as a 2-D tensor.
+ *
+ * The `rows * cols` floats at `transform.table` become the values of a
+ * DenseFPElementsAttr whose type is a ranked tensor of shape {rows, cols}
+ * with element type `type`. `transform.scalarFactor` is not read here.
+ *
+ * NOTE(review): the values are always supplied as `float`, independent of
+ * `type`; confirm that callers only pass a compatible element type.
+ */
+Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+ TransformMatrix transform, Type type) {
+ ArrayRef<float> const_vec(transform.table, transform.rows * transform.cols);
+ // `const_vec` is a flat view of the table; the tensor type below gives it its 2-D shape.
+ return rewriter.create<arith::ConstantOp>(
+ loc, DenseFPElementsAttr::get(
+ RankedTensorType::get(
+ SmallVector<int64_t>{transform.rows, transform.cols}, type),
+ const_vec));
+}
+/* Extracts one 2-D (height x width) slice from the 4-D tensor `source`.
+ *
+ * `heightIdx` and `widthIdx` select the two dimensions of `source` that are
+ * kept at full extent; every other dimension is sliced to size 1. The slice
+ * starts at `outLoopIndex` along dimension `outLoopIdx`, at `inLoopIndex`
+ * along dimension `inLoopIdx`, and at 0 elsewhere, with unit strides. The
+ * size-1 dimensions are then dropped by a rank-reducing extract_slice, so
+ * the returned value is a {height, width} tensor of the source element type.
+ *
+ * NOTE(review): offsets and sizes are hard-coded to four entries and the
+ * extents are turned into index attributes, so this presumably expects a
+ * rank-4 `source` with static height and width -- confirm with the callers.
+ */
+Value extract2DData(RewriterBase &rewriter, Location loc, Value source,
+ Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+ int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx) {
+ auto sourceType = cast<ShapedType>(source.getType());
+ Type elementType = sourceType.getElementType();
+ auto sourceShape = sourceType.getShape();
+ int64_t height = sourceShape[heightIdx];
+ int64_t width = sourceShape[widthIdx];
+ // Offsets default to 0 and sizes to 1; only the loop and spatial dimensions are overridden.
+ auto zeroIndex = rewriter.getIndexAttr(0);
+ auto oneIndex = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+ offsets[outLoopIdx] = outLoopIndex;
+ offsets[inLoopIdx] = inLoopIndex;
+ SmallVector<OpFoldResult, 4> sizes(4, oneIndex);
+ sizes[heightIdx] = rewriter.getIndexAttr(height);
+ sizes[widthIdx] = rewriter.getIndexAttr(width);
+ SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+ SmallVector<int64_t> targetShape(4, 1);
+ targetShape[heightIdx] = height;
+ targetShape[widthIdx] = width;
+ // First take a 4-D slice whose non-spatial dimensions all have size 1.
+ auto targetType = RankedTensorType::get(targetShape, elementType);
+ auto extractFilterOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, targetType, source, offsets, sizes, strides);
+ // Then drop the unit dimensions to obtain the 2-D result.
+ auto extractFilterType = RankedTensorType::get({height, width}, elementType);
+ auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, extractFilterOp, extractFilterType);
+
+ return extractFilter;
+}
+/* Inserts the 2-D tensor `source` into the 4-D tensor `dest` and returns
+ * the result of the insertion.
+ *
+ * `source` is first placed, through a rank-reducing insert_slice, into a
+ * 4-D tensor.empty whose shape is 1 everywhere except `height` at
+ * `heightIdx` and `width` at `widthIdx`. That 4-D slice is then inserted
+ * into `dest` at offset `outLoopIndex` along dimension `outLoopIdx` and
+ * `inLoopIndex` along dimension `inLoopIdx` (0 elsewhere), with sizes
+ * matching the slice shape and unit strides.
+ */
+Value insert2DData(RewriterBase &rewriter, Location loc, Value source,
+ Value dest, Value outLoopIndex, Value inLoopIndex,
+ int64_t height, int64_t width, int64_t outLoopIdx,
+ int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx) {
+ auto sourceType = cast<ShapedType>(source.getType());
+ Type elementType = sourceType.getElementType();
+ SmallVector<int64_t> sliceShape(4, 1);
+ sliceShape[heightIdx] = height;
+ sliceShape[widthIdx] = width;
+ auto init = rewriter.create<tensor::EmptyOp>(loc, sliceShape, elementType);
+ auto result = tensor::createCanonicalRankReducingInsertSliceOp(rewriter, loc,
+ source, init);
+ // Offsets default to 0 and sizes to 1; only the loop and spatial dimensions are overridden.
+ auto zeroIndex = rewriter.getIndexAttr(0);
+ auto oneIndex = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult, 4> retOffsets(4, zeroIndex);
+ retOffsets[outLoopIdx] = outLoopIndex;
+ retOffsets[inLoopIdx] = inLoopIndex;
+ SmallVector<OpFoldResult, 4> retSizes(4, oneIndex);
+ retSizes[heightIdx] = rewriter.getIndexAttr(height);
+ retSizes[widthIdx] = rewriter.getIndexAttr(width);
+ SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+ // Write the 4-D slice into `dest` at the computed position.
+ auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, result, dest, retOffsets, retSizes, strides);
+
+ return insertSliceOp;
+}
+
+Value collaps2DData(RewriterBase &rewriter, Location loc, Value data) {
----------------
ftynse wrote:
Typo: `collaps2DData` should be spelled `collapse2DData`.
https://github.com/llvm/llvm-project/pull/94470
More information about the Mlir-commits
mailing list