[Mlir-commits] [mlir] [mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm (PR #96181)
    llvmlistbot at llvm.org 
       
    Mon Jun 24 07:23:00 PDT 2024
    
    
  
================
@@ -0,0 +1,329 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. We can get
+/// the input dimension, alpha, from the formula, alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, we know its input size is 4.
+/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+/// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+/// Utility function to linearize data. The input shape is
+/// [tileH, tileW, H, W, N, C] or [tileH, tileW, H, W, C, F]. The function will
+/// convert the shape to [tileH x tileW x H x W, N, C] or
+/// [tileH x tileW x H x W, C, F].
+static Value collapseData(RewriterBase &rewriter, Location loc, Value data) {
+  auto type = cast<ShapedType>(data.getType());
+  assert(type.hasStaticShape() && "only support static shapes.");
+  Type elementType = type.getElementType();
+  ArrayRef<int64_t> shape = type.getShape();
+  auto collapseType = RankedTensorType::get(
+      {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]},
+      elementType);
+  SmallVector<ReassociationIndices> reassociation = {{0, 1, 2, 3}, {4}, {5}};
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
+                                                  reassociation);
+}
+
+/// This function generates linalg.batch_matmul to multiply input with filter.
+/// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
+/// tileH x tileW x H x W data as the 1-dimensional data array. That is to
+/// convert [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this
+/// way, we can convert 6-dimensional inputs to 3-dimensional representation
+/// that is suitable for linalg.batch_matmul.
+///
+/// Batched matmul will do the matrix multiply with the reduction on channel.
+///
+/// We get
+///
+/// %collapsed_input = tensor.collapse_shape %input
+/// %collapsed_filter = tensor.collapse_shape %filter
+/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+/// %expanded_ret = tensor.expand_shape %ret
+///
+/// After this function, we get return value with data layout
+/// (tileH, tileW, H, W, N, F).
+static Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                            Value transformedFilter, Value transformedInput) {
+  Value collapseFilter = collapseData(rewriter, loc, transformedFilter);
+  Value collapseInput = collapseData(rewriter, loc, transformedInput);
+
+  // Batched matrix multiply.
+  auto filterType = cast<ShapedType>(transformedFilter.getType());
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  auto inputType = cast<ShapedType>(transformedInput.getType());
+  Type inputElemType = inputType.getElementType();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+
+  auto matmulType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3],
+       inputShape[4], filterShape[5]},
+      inputElemType);
----------------
Max191 wrote:
I believe the filter transform shape in the comment below is not quite right, which may be why the matmul is batched this way. I would expect the matmul shape to be:
```suggestion
  auto matmulType = RankedTensorType::get(
      {inputShape[2] * inputShape[3],
       inputShape[0] * inputShape[1] * inputShape[4],
       filterShape[5]},
      inputElemType);
```
The input transform result shape is `(tileH, tileW, alphaH, alphaW, inputN, inputC)`, and the filter transform shape I would expect is `(alphaH, alphaW, filterC, filterF)`. The shared dimensions are `alphaH` and `alphaW`, so those should be the batch dimensions. `tileH` and `tileW` (together with `inputN`) would be `M` dimensions of the batch matmul.
An additional suggestion based on this: change the layout of the input transform result from `(tileH, tileW, alphaH, alphaW, inputN, inputC)` to `(alphaH, alphaW, tileH, tileW, inputN, inputC)`, so that the batch dimensions are outermost in the batch_matmul. This can help matmul performance in some cases, and it means you can use the `linalg.batch_matmul` named op instead of a `linalg.generic` op. With this layout, the matmul shape would be:
```suggestion
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4],
       filterShape[5]},
      inputElemType);
```
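To make the two shape computations easier to compare, here is a minimal standalone sketch in plain C++ with no MLIR dependency. The helper names `matmulShapeTileOuter` and `matmulShapeAlphaOuter` and the `MatmulShape` alias are hypothetical and are not part of the patch. They only mirror the index arithmetic in the two suggestions above, assuming a fully static 6-D transformed input shape.

```cpp
#include <array>
#include <cstdint>

// Result shape of a batched matmul: {batch, M, N}.
using MatmulShape = std::array<int64_t, 3>;

// Hypothetical helper (not in the patch). Assumes the transformed input has
// layout (tileH, tileW, alphaH, alphaW, inputN, inputC).
// batch = alphaH * alphaW, M = tileH * tileW * inputN, N = filterF.
inline MatmulShape matmulShapeTileOuter(const std::array<int64_t, 6> &in,
                                        int64_t filterF) {
  return {in[2] * in[3], in[0] * in[1] * in[4], filterF};
}

// Hypothetical helper (not in the patch). Assumes the transformed input has
// layout (alphaH, alphaW, tileH, tileW, inputN, inputC), i.e. the batch
// dimensions are outermost.
// batch = alphaH * alphaW, M = tileH * tileW * inputN, N = filterF.
inline MatmulShape matmulShapeAlphaOuter(const std::array<int64_t, 6> &in,
                                         int64_t filterF) {
  return {in[0] * in[1], in[2] * in[3] * in[4], filterF};
}
```

For example, with F(4, 3) (so alpha = 4 + 3 - 1 = 6), a 2x2 grid of tiles, `inputN = 1`, `inputC = 5`, and `filterF = 2`, both helpers give the same `{36, 4, 2}` result. The two layouts differ only in which input dimensions are outermost, not in the logical matmul shape.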
https://github.com/llvm/llvm-project/pull/96181
    
    