[Mlir-commits] [mlir] [mlir][linalg] Add TransposeConv2D Pass (PR #68567)

Jack Frankland llvmlistbot at llvm.org
Tue Oct 17 08:24:47 PDT 2023


================
@@ -0,0 +1,116 @@
+//===- TransposeConv2D.cpp - Convolution transposition --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSECONV2D
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+// Convolution converter that rewrites linalg.conv_2d_nhwc_fhwc and
+// linalg.conv_2d_nhwc_fhwc_q into linalg.transpose + linalg.conv_2d_nhwc_hwcf
+// and linalg.transpose + linalg.conv_2d_nhwc_hwcf_q respectively.
+template <typename FHWCConvOp, typename HWCFConvOp>
+class ConvConverter : public OpRewritePattern<FHWCConvOp> {
+public:
+  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(FHWCConvOp op,
+                                PatternRewriter &rewriter) const final {
+    // Transpose the weights.
+    //
+    // To do this we first need to construct a permutation of the weight tensor
+    // dimensions. For a 2D convolution this will be known statically as [1, 2,
+    // 3, 0] however we construct the vector dynamically to future proof this
+    // logic so it can be extended to convolutions of higher dimensions.
+    auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+    auto weightPerm = SmallVector<int64_t>(resultTy.getRank() - 1);
+    std::iota(std::begin(weightPerm), std::end(weightPerm), 1);
+    weightPerm.push_back(0);
+
+    // Create the type for the transposed weight tensor since this will be
+    // different from the original weight type.
+    auto weight = op->getOperand(1);
+    auto weightTy = cast<ShapedType>(weight.getType());
+    auto newWeightShape = SmallVector<int64_t>(weightPerm.size());
+    std::generate(std::begin(newWeightShape), std::end(newWeightShape),
+                  [dim = 0, &weightTy, &weightPerm]() mutable {
+                    return weightTy.getShape()[weightPerm[dim++]];
+                  });
+    auto newWeightTy =
+        RankedTensorType::get(newWeightShape, weightTy.getElementType());
+
+    // Because linalg.tranpose expects an "out" parameter we need to pass it a
+    // tensor of zeros of the result type so here we construct that tensor.
+    auto resultETy = resultTy.getElementType();
+    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    auto loc = op->getLoc();
+    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, newWeightTy.getShape(), resultETy);
+    auto zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
+    auto zeroTensor = rewriter
+                          .create<linalg::FillOp>(loc, ValueRange{zero},
+                                                  ValueRange{emptyTensor})
+                          .result();
+
+    // We can then construct the transposition on our weights.
+    weight =
+        rewriter
+            .create<linalg::TransposeOp>(loc, weight, zeroTensor, weightPerm)
+            .getResult()[0];
+
+    // Create the convolution.
+    //
+    // The weights are always the second input argument.
----------------
FranklandJack wrote:

Good point, updated! 

https://github.com/llvm/llvm-project/pull/68567


More information about the Mlir-commits mailing list