[Mlir-commits] [mlir] [mlir][linalg] Implement Winograd Conv2D. (PR #94470)
Hsiangkai Wang
llvmlistbot at llvm.org
Thu Jun 20 02:57:26 PDT 2024
================
@@ -0,0 +1,1022 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// Implement the Winograd Conv2D algorithm. The implementation is based on
+/// the paper: Fast Algorithms for Convolutional Neural Networks
+/// (https://arxiv.org/abs/1509.09308)
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+// clang-format off
+// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
+// where m is the output dimension and r is the filter dimension, is
+//
+// Y = A^T x [ (G x g x G^T) (*) (B^T x d x B) ] x A
+//
+// where (*) denotes element-wise multiplication, g is the filter, and d is
+// the input data. We need to prepare 6 constant transformation matrices,
+// G, G^T, B^T, B, A^T, and A for this formula.
+//
+// The following tables define these constant transformation matrices for
+// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
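+//
+// For reference, with alpha = m + r - 1, the shapes of the tables below are:
+//   G:   alpha x r       G^T: r x alpha
+//   B^T: alpha x alpha   B:   alpha x alpha
+//   A^T: m x alpha       A:   alpha x m
+// e.g. for F(4 x 4, 3 x 3), alpha = 6, so G is 6x3, B^T is 6x6, and A^T is
+// 4x6.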
+constexpr float G_2x2_3x3[] = {
+ -1, 0, 0,
+ 1./2, -1./2, 1./2,
+ 1./2, 1./2, 1./2,
+ 0, 0, 1
+};
+
+constexpr float GT_2x2_3x3[] = {
+ -1, 1./2, 1./2, 0,
+ 0, -1./2, 1./2, 0,
+ 0, 1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = {
+ -1, 0, 1, 0,
+ 0, -1, 1, 0,
+ 0, 1, 1, 0,
+ 0, -1, 0, 1
+};
+
+constexpr float B_2x2_3x3[] = {
+ -1, 0, 0, 0,
+ 0, -1, 1, -1,
+ 1, 1, 1, 0,
+ 0, 0, 0, 1
+};
+
+constexpr float AT_2x2_3x3[] = {
+ 1, 1, 1, 0,
+ 0, -1, 1, 1
+};
+
+constexpr float A_2x2_3x3[] = {
+ 1, 0,
+ 1, -1,
+ 1, 1,
+ 0, 1
+};
+
+constexpr float G_4x4_3x3[] = {
+ 1, 0, 0,
+ -1./3, 1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6, 1./3,
+ 1./12, 1./6, 1./3,
+ 0, 0, 1
+};
+
+constexpr float GT_4x4_3x3[] = {
+ 1, -1./3, -1./3, 1./12, 1./12, 0,
+ 0, 1./3, -1./3, -1./6, 1./6, 0,
+ 0, -1./3, -1./3, 1./3, 1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = {
+ 1./4, 0, -5./16, 0, 1./16, 0,
+ 0, 1./4, -1./4, -1./16, 1./16, 0,
+ 0, -1./4, -1./4, 1./16, 1./16, 0,
+ 0, 1./4, -1./8, -1./4, 1./8, 0,
+ 0, -1./4, -1./8, 1./4, 1./8, 0,
+ 0, 1./4, 0, -5./16, 0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = {
+ 1./4, 0, 0, 0, 0, 0,
+ 0, 1./4, -1./4, 1./4, -1./4, 1./4,
+ -5./16, -1./4, -1./4, -1./8, -1./8, 0,
+ 0, -1./16, 1./16, -1./4, 1./4, -5./16,
+ 1./16, 1./16, 1./16, 1./8, 1./8, 0,
+ 0, 0, 0, 0, 0, 1./16
+};
+
+constexpr float AT_4x4_3x3[] = {
+ 1./8, 1./4, 1./4, 1./8, 1./8, 0,
+ 0, -1./4, 1./4, -1./4, 1./4, 0,
+ 0, 1./4, 1./4, 1./2, 1./2, 0,
+ 0, -1./4, 1./4, -1, 1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = {
+ 1./8, 0, 0, 0,
+ 1./4, -1./4, 1./4, -1./4,
+ 1./4, 1./4, 1./4, 1./4,
+ 1./8, -1./4, 1./2, -1,
+ 1./8, 1./4, 1./2, 1,
+ 0, 0, 0, 1./2
+};
+
+constexpr float G_2x2_5x5[] = {
+ 1, 0, 0, 0, 0,
+ 1./6, -1./6, 1./6, -1./6, 1./6,
+ -1./6, -1./6, -1./6, -1./6, -1./6,
+ -4./15, 2./15, -1./15, 1./30, -1./60,
+ 1./60, 1./30, 1./15, 2./15, 4./15,
+ 0, 0, 0, 0, 1
+};
+
+constexpr float GT_2x2_5x5[] = {
+ 1, 1./6, -1./6, -4./15, 1./60, 0,
+ 0, -1./6, -1./6, 2./15, 1./30, 0,
+ 0, 1./6, -1./6, -1./15, 1./15, 0,
+ 0, -1./6, -1./6, 1./30, 2./15, 0,
+ 0, 1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = {
+ 1./8, 3./16, -1./4, -3./16, 1./8, 0,
+ 0, 1./8, 1./16, -5./16, 1./8, 0,
+ 0, -1./8, -5./16, -1./16, 1./8, 0,
+ 0, 1./4, -1./8, -1./4, 1./8, 0,
+ 0, -1./8, -1./4, 1./8, 1./4, 0,
+ 0, 1./8, 3./16, -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = {
+ 1./8, 0, 0, 0, 0, 0,
+ 3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
+ -1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
+ -3./16, -5./16, -1./16, -1./4, 1./8, -1./4,
+ 1./8, 1./8, 1./8, 1./8, 1./4, -3./16,
+ 0, 0, 0, 0, 0, 1./8
+};
+
+constexpr float AT_2x2_5x5[] = {
+ 1./2, 1, 1, 2, 1, 0,
+ 0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = {
+ 1./2, 0,
+ 1, -1,
+ 1, 1,
+ 2, -1,
+ 1, 2,
+ 0, 1./2
+};
+// clang-format on
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+// We use F(m, r) to define the size of minimal filtering algorithms, where
+// m is the output dimension and r is the filter dimension. The input
+// dimension, alpha, is given by the formula alpha = m + r - 1.
+//
+// For example, when m = 2 and r = 3, the input size is 4. The Conv2D
+// operates on 4x4 input data with a 3x3 filter and produces a 2x2 output.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
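+// Describes one of the constant transformation matrices above: a pointer to
+// the raw table plus its dimensions. scalarFactor is a constant that the
+// final result is multiplied by; it is only used by the A^T and A matrices
+// of the output transform (see outputTransform below).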
+struct TransformMatrix {
+ TransformMatrix(const float *table, int64_t rows, int64_t cols,
+ int64_t scalarFactor = 1)
+ : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+
+ const float *table;
+ int64_t rows;
+ int64_t cols;
+ int64_t scalarFactor;
+};
+
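+// Materialize one of the constant transformation matrices as an
+// arith.constant holding a rows x cols tensor of the given element type.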
+Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+ TransformMatrix transform, Type type) {
+ ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+
+ return rewriter.create<arith::ConstantOp>(
+ loc, DenseFPElementsAttr::get(
+ RankedTensorType::get(
+ SmallVector<int64_t>{transform.rows, transform.cols}, type),
+ constVec));
+}
+
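+// Extract a 2-D (height x width) tile from a 4-D tensor. outLoopIndex and
+// inLoopIndex are the current values of the two surrounding loops; they give
+// the offsets along the outLoopIdx and inLoopIdx dimensions, while the
+// heightIdx and widthIdx dimensions are taken in full. The rank-4 slice is
+// then rank-reduced to a 2-D (height x width) tensor.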
+Value extract2DData(RewriterBase &rewriter, Location loc, Value source,
+ Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx,
+ int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx) {
+ auto sourceType = cast<ShapedType>(source.getType());
+ Type elementType = sourceType.getElementType();
+ auto sourceShape = sourceType.getShape();
+ int64_t height = sourceShape[heightIdx];
+ int64_t width = sourceShape[widthIdx];
+
+ auto zeroIndex = rewriter.getIndexAttr(0);
+ auto oneIndex = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
+ offsets[outLoopIdx] = outLoopIndex;
+ offsets[inLoopIdx] = inLoopIndex;
+ SmallVector<OpFoldResult, 4> sizes(4, oneIndex);
+ sizes[heightIdx] = rewriter.getIndexAttr(height);
+ sizes[widthIdx] = rewriter.getIndexAttr(width);
+ SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+ SmallVector<int64_t> targetShape(4, 1);
+ targetShape[heightIdx] = height;
+ targetShape[widthIdx] = width;
+
+ auto targetType = RankedTensorType::get(targetShape, elementType);
+ auto extractFilterOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, targetType, source, offsets, sizes, strides);
+
+ auto extractFilterType = RankedTensorType::get({height, width}, elementType);
+ auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, extractFilterOp, extractFilterType);
+
+ return extractFilter;
+}
+
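+// The inverse of extract2DData: expand a 2-D (height x width) tensor to rank
+// 4 and insert it into dest, using outLoopIndex and inLoopIndex as the
+// offsets along the outLoopIdx and inLoopIdx dimensions.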
+Value insert2DData(RewriterBase &rewriter, Location loc, Value source,
+ Value dest, Value outLoopIndex, Value inLoopIndex,
+ int64_t height, int64_t width, int64_t outLoopIdx,
+ int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx) {
+ auto sourceType = cast<ShapedType>(source.getType());
+ Type elementType = sourceType.getElementType();
+ SmallVector<int64_t> sliceShape(4, 1);
+ sliceShape[heightIdx] = height;
+ sliceShape[widthIdx] = width;
+ auto init = rewriter.create<tensor::EmptyOp>(loc, sliceShape, elementType);
+ auto result = tensor::createCanonicalRankReducingInsertSliceOp(rewriter, loc,
+ source, init);
+
+ auto zeroIndex = rewriter.getIndexAttr(0);
+ auto oneIndex = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult, 4> retOffsets(4, zeroIndex);
+ retOffsets[outLoopIdx] = outLoopIndex;
+ retOffsets[inLoopIdx] = inLoopIndex;
+ SmallVector<OpFoldResult, 4> retSizes(4, oneIndex);
+ retSizes[heightIdx] = rewriter.getIndexAttr(height);
+ retSizes[widthIdx] = rewriter.getIndexAttr(width);
+ SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+ auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, result, dest, retOffsets, retSizes, strides);
+
+ return insertSliceOp;
+}
+
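+// Collapse the first two dimensions of a 4-D tensor, e.g. (H, W, N, C) to
+// (H * W, N, C), so that the data can be fed to linalg.batch_matmul.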
+Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
+ auto type = cast<ShapedType>(data.getType());
+ auto elementType = type.getElementType();
+ auto shape = type.getShape();
+ auto collapseType = RankedTensorType::get(
+ {shape[0] * shape[1], shape[2], shape[3]}, elementType);
+ SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}, {3}};
+ return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
+ reassociation);
+}
+
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrices are 2-dimensional, so we first extract an
+// H x W tile from FHWC, generating 2 levels of loops to iterate on F and C.
+// After the transformation, we get
+//
+// scf.for %f = lo_f to hi_f step 1
+// scf.for %c = lo_c to hi_c step 1
+// %extracted = extract filter<h x w> from filter<f x h x w x c>
+// %ret = linalg.matmul G, %extracted
+// %ret = linalg.matmul %ret, GT
+// %inserted = insert %ret into filter<h x w x c x f>
+//
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+ Value retValue, int64_t m, int64_t r,
+ bool leftTransform = true, bool rightTransform = true) {
+ // Map from (m, r) to G transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ GMatrices = {
+ {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+ {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+ {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+ };
+
+ // Map from (m, r) to GT transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ GTMatrices = {
+ {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+ {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+ {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+ };
+
+ auto filterType = cast<ShapedType>(filter.getType());
+ Type elementType = filterType.getElementType();
+ auto filterShape = filterType.getShape(); // F, H, W, C
+ int64_t filterF = filterShape[0];
+ int64_t filterH = filterShape[1];
+ int64_t filterW = filterShape[2];
+ int64_t filterC = filterShape[3];
+
+ if (filterH != r && filterH != 1)
+ return Value();
+ if (filterW != r && filterW != 1)
+ return Value();
+
+ // Return shape is <H x W x C x F>
+ auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+ auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+ auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto outerForOp =
+ rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
+ Block *outerForBody = outerForOp.getBody();
+ rewriter.setInsertionPointToStart(outerForBody);
+ Value FIter = outerForBody->getArgument(0);
+
+ auto innerForOp = rewriter.create<scf::ForOp>(
+ loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+ Block *innerForBody = innerForOp.getBody();
+ rewriter.setInsertionPointToStart(innerForBody);
+ Value CIter = innerForBody->getArgument(0);
+
+ // Extract (H, W) from (F, H, W, C)
+ auto extractFilter =
+ extract2DData(rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0,
+ /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+ TransformMapKeyTy key = {m, r};
+ int64_t retRows = 1;
+ Value matmulRetValue = extractFilter;
+ if (leftTransform) {
+ // Get constant transform matrix G
+ auto it = GMatrices.find(key);
+ if (it == GMatrices.end())
+ return Value();
+ const TransformMatrix &GMatrix = it->second;
+
+ retRows = GMatrix.rows;
+ auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+ auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ elementType);
+
+ Value G = create2DTransformMatrix(rewriter, loc, GMatrix, elementType);
+ // Multiply G x g
+ auto matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+ matmulRetValue = matmulOp.getResult(0);
+ }
+
+ if (rightTransform) {
+ // Get constant transform matrix GT
+ auto it = GTMatrices.find(key);
+ if (it == GTMatrices.end())
+ return Value();
+ const TransformMatrix &GTMatrix = it->second;
+
+ auto matmulType =
+ RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
+ auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ elementType);
+
+ Value GT = create2DTransformMatrix(rewriter, loc, GTMatrix, elementType);
+ // Multiply u = (G x g) x GT
+ auto matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+ matmulRetValue = matmulOp.getResult(0);
+ }
+
+ // Insert (H, W) to (H, W, C, F)
+ Value iterArg = innerForOp.getRegionIterArgs()[0];
+ int64_t retHeight = leftTransform ? m + r - 1 : 1;
+ int64_t retWidth = rightTransform ? m + r - 1 : 1;
+ auto insertSliceOp = insert2DData(
+ rewriter, loc, matmulRetValue, iterArg, FIter, CIter, retHeight, retWidth,
+ /*outLoopIdx=*/3, /*inLoopIdx=*/2, /*heightIdx=*/0, /*widthIdx=*/1);
+
+ rewriter.create<scf::YieldOp>(loc, insertSliceOp);
+
+ rewriter.setInsertionPointToEnd(outerForBody);
+ rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+ rewriter.setInsertionPointAfter(outerForOp);
+
+ return outerForOp.getResult(0);
+}
+
+// This function transforms the input. The data layout of the input is NHWC.
+// The transformation matrices are 2-dimensional, so we first extract an
+// H x W tile from NHWC, generating 2 levels of loops to iterate on N and C.
+// After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+// scf.for %c = lo_c to hi_c step 1
+// %extracted = extract input<h x w> from input<n x h x w x c>
+// %ret = linalg.matmul BT, %extracted
+// %ret = linalg.matmul %ret, B
+// %inserted = insert %ret into input<h x w x n x c>
+//
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+ Value retValue, int64_t m, int64_t r,
+ bool leftTransform = true, bool rightTransform = true) {
+ // Map from (m, r) to BT transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ BTMatrices = {
+ {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+ {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+ {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+ };
+
+ // Map from (m, r) to B transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ BMatrices = {
+ {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+ {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+ {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+ };
+
+ auto inputType = cast<ShapedType>(input.getType());
+ Type elementType = inputType.getElementType();
+ auto inputShape = inputType.getShape(); // N, H, W, C
+ int64_t inputN = inputShape[0];
+ int64_t inputH = inputShape[1];
+ int64_t inputW = inputShape[2];
+ int64_t inputC = inputShape[3];
+ int64_t alphaH = leftTransform ? m + r - 1 : 1;
+ int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+ if (inputH != alphaH && inputH != 1)
+ return Value();
+ if (inputW != alphaW && inputW != 1)
+ return Value();
+
+ auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
+ auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
+ auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+ auto outerForOp =
+ rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, retValue);
+ Block *outerForBody = outerForOp.getBody();
+ rewriter.setInsertionPointToStart(outerForBody);
+ Value NIter = outerForBody->getArgument(0);
+
+ auto innerForOp = rewriter.create<scf::ForOp>(
+ loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+ Block *innerForBody = innerForOp.getBody();
+ rewriter.setInsertionPointToStart(innerForBody);
+ Value CIter = innerForBody->getArgument(0);
+
+ // Extract (H, W) from (N, H, W, C)
+ auto extractInput =
+ extract2DData(rewriter, loc, input, NIter, CIter, /*outLoopIdx=*/0,
+ /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+ TransformMapKeyTy key = {m, r};
+ int64_t retRows = 1;
+ int64_t retCols = 1;
+ Value matmulRetValue = extractInput;
+ if (leftTransform) {
+ // Get constant transform matrix BT
+ auto it = BTMatrices.find(key);
+ if (it == BTMatrices.end())
+ return Value();
+ const TransformMatrix &BTMatrix = it->second;
+
+ retRows = BTMatrix.rows;
+ auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+ auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ elementType);
+
+ Value BT = create2DTransformMatrix(rewriter, loc, BTMatrix, elementType);
+ // Multiply BT x d
+ auto matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+ matmulRetValue = matmulOp.getResult(0);
+ }
+
+ if (rightTransform) {
+ // Get constant transform matrix B
+ auto it = BMatrices.find(key);
+ if (it == BMatrices.end())
+ return Value();
+ const TransformMatrix &BMatrix = it->second;
+
+ retCols = BMatrix.cols;
+ auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+ auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ elementType);
+ Value B = create2DTransformMatrix(rewriter, loc, BMatrix, elementType);
+ // Multiply v = (BT x d) x B
+ auto matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
+ matmulRetValue = matmulOp.getResult(0);
+ }
+
+ // Insert v
+ // Insert (H, W) to (H, W, N, C)
+ Value iterArg = innerForOp.getRegionIterArgs()[0];
+ auto combinedVal = insert2DData(
+ rewriter, loc, matmulRetValue, iterArg, NIter, CIter, retRows, retCols,
+ /*outLoopIdx=*/2, /*inLoopIdx=*/3, /*heightIdx=*/0, /*widthIdx=*/1);
+
+ rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+ rewriter.setInsertionPointToEnd(outerForBody);
+ rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+ rewriter.setInsertionPointAfter(outerForOp);
+
+ return outerForOp.getResult(0);
+}
+
+// This function generates linalg.batch_matmul to multiply input with filter.
+// linalg.batch_matmul only supports 3-dimensional operands, so we treat the
+// H x W data as a single flattened dimension, i.e., we convert [H, W, N, C]
+// to [H x W, N, C]. In this way, the 4-dimensional data is reshaped into a
+// 3-dimensional representation suitable for linalg.batch_matmul.
+//
+// The batched matmul performs the matrix multiplication with the reduction
+// over the channel dimension.
+//
+// We get
+//
+// %collapsed_input = tensor.collapse_shape %input
+// %collapsed_filter = tensor.collapse_shape %filter
+// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+// %expanded_ret = tensor.expand_shape %ret
+//
+// After this function, we get a return value with data layout (H, W, N, F).
+//
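+// For example, for F(4 x 4, 3 x 3) (alpha = 6), the transformed input of
+// shape (6, 6, N, C) and the transformed filter of shape (6, 6, C, F) are
+// collapsed to (36, N, C) and (36, C, F), multiplied into (36, N, F), and
+// expanded back to (6, 6, N, F).
+//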
+Value matrixMultiply(RewriterBase &rewriter, Location loc,
+ Value transformedFilter, Value transformedInput) {
+ auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
+ auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
+
+ // Batched matrix multiply
+ auto filterType = cast<ShapedType>(transformedFilter.getType());
+ auto filterShape = filterType.getShape();
+ auto inputType = cast<ShapedType>(transformedInput.getType());
+ auto inputElemType = inputType.getElementType();
+ auto inputShape = inputType.getShape();
+
+ auto matmulType = RankedTensorType::get(
+ {inputShape[0] * inputShape[1], inputShape[2], filterShape[3]},
+ inputElemType);
+ Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ inputElemType);
+
+ auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+ loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+ ValueRange{init});
+
+ // Expand matmul result
+ SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}, {3}};
+ auto expandType = RankedTensorType::get(
+ {inputShape[0], inputShape[1], inputShape[2], filterShape[3]},
+ inputElemType);
+ auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+ loc, expandType, matmulOp.getResult(0), reassociation);
+ return expandOutput;
+}
+
+// This function transforms the output. The data layout of the incoming value
+// is HWNF. The transformation matrices are 2-dimensional, so we first extract
+// an H x W tile from HWNF, generating 2 levels of loops to iterate on N and
+// F. After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+// scf.for %f = lo_f to hi_f step 1
+// %extracted = extract value<h x w> from value<h x w x n x f>
+// %ret = linalg.matmul AT, %extracted
+// %ret = linalg.matmul %ret, A
+// %inserted = insert %ret into ret<n x h x w x f>
+//
+Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
+ Value output, int64_t m, int64_t r,
+ bool leftTransform = true, bool rightTransform = true) {
+ // Map from (m, r) to AT transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ ATMatrices = {
+ {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+ {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+ {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+ };
+
+ // Map from (m, r) to A transform matrix.
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+ AMatrices = {
+ {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+ {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+ {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+ };
+
+ auto valueType = cast<ShapedType>(value.getType());
+ Type elementType = valueType.getElementType();
+ auto valueShape = valueType.getShape(); // H, W, N, F
+ int64_t valueH = valueShape[0];
+ int64_t valueW = valueShape[1];
+ int64_t valueN = valueShape[2];
+ int64_t valueF = valueShape[3];
+ int64_t alphaH = leftTransform ? m + r - 1 : 1;
+ int64_t alphaW = rightTransform ? m + r - 1 : 1;
+
+ if (valueH != alphaH && valueH != 1)
+ return Value();
+ if (valueW != alphaW && valueW != 1)
+ return Value();
+
+ auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
+ auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
+ auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+ auto outerForOp =
+ rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, output);
+ Block *outerForBody = outerForOp.getBody();
+ rewriter.setInsertionPointToStart(outerForBody);
+ Value NIter = outerForBody->getArgument(0);
+
+ auto innerForOp = rewriter.create<scf::ForOp>(
+ loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+ Block *innerForBody = innerForOp.getBody();
+ rewriter.setInsertionPointToStart(innerForBody);
+ Value FIter = innerForBody->getArgument(0);
+
+ // Extract (H, W) from (H, W, N, F)
+ auto extractValue =
+ extract2DData(rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/2,
+ /*inLoopIdx=*/3, /*heightIdx=*/0, /*widthIdx=*/1);
+
+ TransformMapKeyTy key = {m, r};
+ int64_t retRows = 1;
+ int64_t retCols = 1;
+ int64_t leftScalarFactor = 1;
+ int64_t rightScalarFactor = 1;
+ Value matmulRetValue = extractValue;
+ if (leftTransform) {
+ // Get constant transform matrix AT
+ auto it = ATMatrices.find(key);
+ if (it == ATMatrices.end())
+ return Value();
+ const TransformMatrix &ATMatrix = it->second;
+
+ leftScalarFactor = ATMatrix.scalarFactor;
+ retRows = ATMatrix.rows;
+ auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
+ auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ elementType);
+
+ Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType);
+ // Multiply AT x m
+ auto matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
+ matmulRetValue = matmulOp.getResult(0);
+ }
+
+ if (rightTransform) {
+ // Get constant transform matrix A
+ auto it = AMatrices.find(key);
+ if (it == AMatrices.end())
+ return Value();
+ const TransformMatrix &AMatrix = it->second;
+
+ rightScalarFactor = AMatrix.scalarFactor;
+ auto matmulType =
+ RankedTensorType::get({retRows, AMatrix.cols}, elementType);
+ retCols = AMatrix.cols;
+ auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ elementType);
+
+ Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType);
+ // Multiply y = (AT x m) x A
+ auto matmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
+ matmulRetValue = matmulOp.getResult(0);
+ }
+
+ // Multiply scalar factor.
+ Value scalarFactor = rewriter.create<arith::ConstantOp>(
+ loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+ auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+ auto init =
+ rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType);
+
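+ // Broadcast the scalar factor with a zero-result affine map; the matrix
+ // operand and the output use the 2-D identity map.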
+ auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+ SmallVector<AffineMap> affineMaps = {AffineMap::get(2, 0, init.getContext()),
+ identityAffineMap, identityAffineMap};
+ auto scalarMatrixOp = rewriter.create<linalg::GenericOp>(
+ loc, matmulType, ValueRange{scalarFactor, matmulRetValue},
+ ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value scalarVal = args[0];
+ Value matrixVal = args[1];
+ Value result = nestedBuilder.create<arith::MulFOp>(nestedLoc, scalarVal,
+ matrixVal);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+ });
+
+ // Insert slice y
+ // Insert (H, W) to (N, H, W, F)
+ Value iterArg = innerForOp.getRegionIterArgs()[0];
+ Value combinedVal =
+ insert2DData(rewriter, loc, scalarMatrixOp.getResult(0), iterArg, NIter,
+ FIter, retRows, retCols,
+ /*outLoopIdx=*/0,
+ /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
+
+ rewriter.create<scf::YieldOp>(loc, combinedVal);
+
+ rewriter.setInsertionPointToEnd(outerForBody);
+ rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+ rewriter.setInsertionPointAfter(outerForOp);
+
+ return outerForOp.getResult(0);
+}
+
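+// Rewrite a linalg.conv_2d_nhwc_fhwc op as Winograd Conv2D F(m, r): transform
+// the filter and the input, multiply them with a batched matmul, and
+// transform the result back to the layout of the original output.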
+FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
+ linalg::Conv2DNhwcFhwcOp convOp,
+ int64_t m, int64_t r) {
+ Value input = convOp.getInputs()[0];
+ Value filter = convOp.getInputs()[1];
+ Value output = convOp.getOutputs()[0];
+
+ auto outputType = cast<ShapedType>(output.getType());
+ int64_t outputH = outputType.getShape()[1];
+ int64_t outputW = outputType.getShape()[2];
+ auto filterType = cast<ShapedType>(filter.getType());
+ auto filterShape = filterType.getShape(); // F, H, W, C
+ int64_t filterF = filterShape[0];
+ int64_t filterH = filterShape[1];
+ int64_t filterW = filterShape[2];
+ int64_t filterC = filterShape[3];
+ auto inputType = cast<ShapedType>(input.getType());
+ auto inputShape = inputType.getShape(); // N, H, W, C
+ int64_t inputN = inputShape[0];
+ int64_t inputC = inputShape[3];
+
+ // Only support F(m x m, r x r), F(m x 1, r x 1), or F(1 x m, 1 x r).
+ if ((outputH != outputW) && (outputH != 1 && outputW != 1))
+ return failure();
+ if ((filterH != filterW) && (filterH != 1 && filterW != 1))
+ return failure();
+
+ if ((outputH == 1 && filterH != 1) || (outputH != 1 && filterH == 1))
+ return failure();
+ if ((outputW == 1 && filterW != 1) || (outputW != 1 && filterW == 1))
+ return failure();
+
+ // The (m, r) configurations with predefined transformation matrices.
+ static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
+ F_2_3, F_4_3, F_2_5};
+
+ TransformMapKeyTy key = {m, r};
+ auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+ // If we cannot find the constant transformation matrix, it means we do
+ // not support this configuration yet.
+ if (it == validConfigs.end())
+ return failure();
+
+ // All the criteria are satisfied. We can do Winograd Conv2D.
+ Location loc = convOp.getLoc();
+
+ // For F(m x 1, r x 1), we only need to do left side transform.
+ bool leftTransform = outputH != 1;
+ // For F(1 x m, 1 x r), we only need to do right side transform.
+ bool rightTransform = outputW != 1;
+
+ // Create operator for filter transform
+ Type elementType = filterType.getElementType();
+ int64_t alphaH = leftTransform ? m + r - 1 : 1;
+ int64_t alphaW = rightTransform ? m + r - 1 : 1;
+ int64_t retHeight = leftTransform ? (outputH / m) * alphaH : 1;
+ int64_t retWidth = rightTransform ? (outputW / m) * alphaW : 1;
----------------
Hsiangkai wrote:
Updated the patch to handle unaligned cases and added tests for them.
https://github.com/llvm/llvm-project/pull/94470