[Mlir-commits] [mlir] [mlir][linalg] Add TransposeConv2D Pass (PR #68567)

Jack Frankland llvmlistbot at llvm.org
Fri Oct 20 07:55:52 PDT 2023


================
@@ -0,0 +1,127 @@
+//===- TransposeConv2D.cpp - Convolution transposition  -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSECONV2D
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+// Convolution converter that applies the following rewrite:
+//
+// clang-format off
+// Before:
+//
+//   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+//                                               strides = dense<2> : tensor<2xi64>}
+//      ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+//     outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+//
+// After:
+//
+//    %cst = arith.constant 0.000000e+00 : f32
+//    %0 = tensor.empty() : tensor<2x2x6x8xf32>
+//    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
+//    %transposed = linalg.transpose ins(%filter : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
+//                  permutation = [1, 2, 3, 0]
+//    %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+//         ins(%input, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%init : tensor<1x2x2x8xf32>)
+//         -> tensor<1x2x2x8xf32>
+//
+// clang-format on
+// with an analogous example for the quantized case.
+template <typename FHWCConvOp, typename HWCFConvOp>
+class ConvConverter : public OpRewritePattern<FHWCConvOp> {
+public:
+  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(FHWCConvOp op,
+                                PatternRewriter &rewriter) const final {
+    // Construct a permutation of the filter tensor dimensions. For a 2D
+    // convolution this will be known statically as [1, 2, 3, 0].
+    auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+    auto filterPerm = SmallVector<int64_t>({1, 2, 3, 0});
+
+    // Create the type for the transposed filter tensor since this will be
+    // different from the original filter type.
+    auto filter = op->getOperand(1);
+    auto filterTy = cast<ShapedType>(filter.getType());
+    auto newfilterShape = SmallVector<int64_t>(filterPerm.size());
+    std::generate(std::begin(newfilterShape), std::end(newfilterShape),
+                  [dim = 0, &filterTy, &filterPerm]() mutable {
+                    return filterTy.getShape()[filterPerm[dim++]];
+                  });
+    auto newFilterTy =
+        RankedTensorType::get(newfilterShape, filterTy.getElementType());
----------------
FranklandJack wrote:

Oops, sorry for the delay; this required more changes than I thought. I've updated the pass to handle both tensors and memrefs now. This requires a bit of branching to accommodate the different dialects (e.g. tensor.empty vs. memref.alloc) and the passing of the result via an out-parameter, which seems to be required for memrefs. I've also simplified the transformation slightly: I realized you don't actually need a tensor/memref of zeros, since you just overwrite it anyway, so you can allocate it and then pass it directly to the transpose.

https://github.com/llvm/llvm-project/pull/68567


More information about the Mlir-commits mailing list