[Mlir-commits] [mlir] [mlir][linalg] Add TransposeConv2D Pass (PR #68567)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Oct 19 02:32:49 PDT 2023


================
@@ -0,0 +1,127 @@
+//===- TransposeConv2D.cpp - Convolution transposition  -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSECONV2D
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+// Convolution converter that applies the following rewrite:
+//
+// clang-format off
+// Before:
+//
+//   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+//                                               strides = dense<2> : tensor<2xi64>}
+//      ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+//     outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+//
+// After:
+//
+//    %cst = arith.constant 0.000000e+00 : f32
+//    %0 = tensor.empty() : tensor<2x2x6x8xf32>
+//    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
+//    %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
+//                  permutation = [1, 2, 3, 0]
+//    %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+//         ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
+//         -> tensor<1x2x2x8xf32>
+//
+// clang-format on
+// with an analogous example for the quantized case.
+template <typename FHWCConvOp, typename HWCFConvOp>
+class ConvConverter : public OpRewritePattern<FHWCConvOp> {
+public:
+  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(FHWCConvOp op,
+                                PatternRewriter &rewriter) const final {
+    // Construct a permutation of the filter tensor dimensions. For a 2D
+    // convolution this will be known statically as [1, 2, 3, 0].
+    auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+    auto filterPerm = SmallVector<int64_t>({1, 2, 3, 0});
+
+    // Create the type for the transposed filter tensor since this will be
+    // different from the original filter type.
+    auto filter = op->getOperand(1);
+    auto filterTy = cast<ShapedType>(filter.getType());
+    auto newfilterShape = SmallVector<int64_t>(filterPerm.size());
+    std::generate(std::begin(newfilterShape), std::end(newfilterShape),
+                  [dim = 0, &filterTy, &filterPerm]() mutable {
+                    return filterTy.getShape()[filterPerm[dim++]];
+                  });
+    auto newFilterTy =
+        RankedTensorType::get(newfilterShape, filterTy.getElementType());
+
+    // Because linalg.transpose expects an "out" parameter we need to pass it a
+    // tensor of zeros of the result type so here we construct that tensor.
+    auto resultETy = resultTy.getElementType();
+    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    auto loc = op->getLoc();
+    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, newFilterTy.getShape(), resultETy);
+    auto zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
+    auto zeroTensor = rewriter
+                          .create<linalg::FillOp>(loc, ValueRange{zero},
+                                                  ValueRange{emptyTensor})
+                          .result();
+
+    // We can then construct the transposition on our filter.
+    filter =
+        rewriter
+            .create<linalg::TransposeOp>(loc, filter, zeroTensor, filterPerm)
+            .getResult()[0];
+
+    // Create the convolution.
+    auto newInputs = SmallVector<Value>{op.getInputs()};
+    // The filter is always the second input argument.
+    newInputs[1] = filter;
+    rewriter.template replaceOpWithNewOp<HWCFConvOp>(
----------------
MacDue wrote:

I don't think the `.template` syntax should be necessary here? The `PatternRewriter` is not a template parameter. 

https://github.com/llvm/llvm-project/pull/68567


More information about the Mlir-commits mailing list