[Mlir-commits] [mlir] [mlir][linalg] Add TransposeConv2D Pass (PR #68567)
Jack Frankland
llvmlistbot at llvm.org
Tue Oct 17 08:25:19 PDT 2023
================
@@ -0,0 +1,116 @@
+//===- TransposeConv2D.cpp - Convolution transposition -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSECONV2D
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+// Rewrite pattern that converts linalg.conv_2d_nhwc_fhwc and
+// linalg.conv_2d_nhwc_fhwc_q into linalg.transpose + linalg.conv_2d_nhwc_hwcf
+// and linalg.transpose + linalg.conv_2d_nhwc_hwcf_q respectively.
+//
+// FHWCConvOp is the op being matched and HWCFConvOp is the op it is replaced
+// with. The weights (operand 1) are transposed from FHWC to HWCF; all other
+// inputs, the outputs, strides and dilations are forwarded unchanged.
+template <typename FHWCConvOp, typename HWCFConvOp>
+class ConvConverter : public OpRewritePattern<FHWCConvOp> {
+public:
+  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(FHWCConvOp op,
+                                PatternRewriter &rewriter) const final {
+    // The rewrite reads the type of result 0 and materialises the transposed
+    // weights with tensor.empty, so bail out when there is no such result.
+    if (op->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(op, "expected a single result");
+    auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+
+    // The weights are always the second input argument.
+    auto weight = op->getOperand(1);
+    auto weightTy = cast<ShapedType>(weight.getType());
+
+    // tensor.empty is built below from a static shape only, so weights with
+    // dynamic (or unranked) shapes are not handled by this pattern.
+    if (!weightTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(op,
+                                         "expected statically shaped weights");
+
+    // Transpose the weights.
+    //
+    // To do this we first need to construct a permutation of the weight tensor
+    // dimensions. For a 2D convolution this will be known statically as [1, 2,
+    // 3, 0] however we derive it from the rank of the weights so the logic can
+    // be extended to convolutions of higher dimensions.
+    SmallVector<int64_t> weightPerm(weightTy.getRank() - 1);
+    std::iota(weightPerm.begin(), weightPerm.end(), 1);
+    weightPerm.push_back(0);
+
+    // Compute the shape of the transposed weights by applying the permutation
+    // to the original weight shape.
+    SmallVector<int64_t> newWeightShape;
+    newWeightShape.reserve(weightPerm.size());
+    for (int64_t dim : weightPerm)
+      newWeightShape.push_back(weightTy.getShape()[dim]);
+
+    // linalg.transpose expects an "out" parameter. Every element of it is
+    // overwritten by the transposition, so an uninitialised tensor.empty is
+    // sufficient and no zero fill is required. It must use the element type of
+    // the weights, which may differ from the result element type (e.g. for the
+    // quantized variants).
+    auto loc = op->getLoc();
+    Value transposeInit = rewriter.create<tensor::EmptyOp>(
+        loc, newWeightShape, weightTy.getElementType());
+
+    // We can then construct the transposition on our weights.
+    weight = rewriter
+                 .create<linalg::TransposeOp>(loc, weight, transposeInit,
+                                              weightPerm)
+                 .getResult()[0];
+
+    // Create the convolution, substituting the transposed weights for the
+    // original second input.
+    auto newInputs = SmallVector<Value>{op.getInputs()};
+    newInputs[1] = weight;
+    rewriter.template replaceOpWithNewOp<HWCFConvOp>(
+        op, resultTy, newInputs, op.getOutputs(), op.getStrides(),
+        op.getDilations());
+    return success();
+  }
+};
+
+// This pass converts NHWC Conv2D operations with FHWC channel orderings to NHWC
+// Conv2D operations with HWCF channel orderings.
----------------
FranklandJack wrote:
Good idea, have added to the tablegen def.
https://github.com/llvm/llvm-project/pull/68567
More information about the Mlir-commits
mailing list