[Mlir-commits] [mlir] [mlir][linalg] Add TransposeConv2D Pass (PR #68567)

Jack Frankland llvmlistbot at llvm.org
Thu Oct 19 06:28:51 PDT 2023

@@ -0,0 +1,127 @@
+//===- TransposeConv2D.cpp - Convolution transposition  -------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+#include <numeric>
+namespace mlir {
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+using namespace mlir;
+namespace {
+// Convolution converter that applies the following rewrite:
+// clang-format off
+// Before:
+//   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+//                                  strides = dense<2> : tensor<2xi64>}
+//      ins (%input, %filter : tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+//     outs (%init : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+// After:
+//    %cst = arith.constant 0.000000e+00 : f32
+//    %0 = tensor.empty() : tensor<2x2x6x8xf32>
+//    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
+//    %transposed = linalg.transpose ins(%filter : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
+//                  permutation = [1, 2, 3, 0]
+//    %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+//         ins(%input, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%init : tensor<1x2x2x8xf32>)
+//         -> tensor<1x2x2x8xf32>
+// clang-format on
+// with an analogous example for the quantized case.
+template <typename FHWCConvOp, typename HWCFConvOp>
+class ConvConverter : public OpRewritePattern<FHWCConvOp> {
+  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(FHWCConvOp op,
+                                PatternRewriter &rewriter) const final {
+    // Construct a permutation of the filter tensor dimensions. For a 2D
+    // convolution this will be known statically as [1, 2, 3, 0].
+    auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+    auto filterPerm = SmallVector<int64_t>({1, 2, 3, 0});
+    // Create the type for the transposed filter tensor since this will be
+    // different from the original filter type.
+    auto filter = op->getOperand(1);
+    auto filterTy = cast<ShapedType>(filter.getType());
+    auto newfilterShape = SmallVector<int64_t>(filterPerm.size());
+    std::generate(std::begin(newfilterShape), std::end(newfilterShape),
+                  [dim = 0, &filterTy, &filterPerm]() mutable {
+                    return filterTy.getShape()[filterPerm[dim++]];
+                  });
+    auto newFilterTy =
+        RankedTensorType::get(newfilterShape, filterTy.getElementType());
FranklandJack wrote:

Also, it looks like:
func.func @conv_2d_nhwc_fhwc(%arg0: memref<?x?x?x?xf32>,
  %arg1: memref<?x?x?x?xf32>,
  %arg2: memref<?x?x?x?xf32>
) -> memref<?x?x?x?xf32> {
  linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%arg0, %arg1 : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
    outs(%arg2 : memref<?x?x?x?xf32>)
  return %arg2 : memref<?x?x?x?xf32>
}
round trips fine, but:
func.func @conv_2d_nhwc_fhwc(%arg0: memref<?x?x?x?xf32>,
  %arg1: memref<?x?x?x?xf32>,
  %arg2: memref<?x?x?x?xf32>
) -> memref<?x?x?x?xf32> {
  linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%arg0, %arg1 : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
    outs(%arg2 : memref<?x?x?x?xf32>) -> memref<?x?x?x?xf32>
  return %arg2 : memref<?x?x?x?xf32>
}
fails to validate (note the trailing return type on the conv).

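Looking at the other Linalg tests, it seems like the trailing return type is valid syntax for the operator, so I'm inclined to say this is a bug. (One caveat: those tests may all use tensor operands. With memref operands the op has buffer semantics and produces no results, which would explain why a trailing result type is rejected here.)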

For now I'll omit it from my lit test so we can get this merged but maybe I should file a bug for this?


