[Mlir-commits] [mlir] 54eec7c - [mlir][tosa] Separate tosa.transpose_conv decomposition and add stride support
Rob Suderman
llvmlistbot at llvm.org
Tue Nov 23 12:22:30 PST 2021
Author: Rob Suderman
Date: 2021-11-23T12:16:44-08:00
New Revision: 54eec7cafc396f3d1444aacf4f1ed71fceb4e503
URL: https://github.com/llvm/llvm-project/commit/54eec7cafc396f3d1444aacf4f1ed71fceb4e503
DIFF: https://github.com/llvm/llvm-project/commit/54eec7cafc396f3d1444aacf4f1ed71fceb4e503.diff
LOG: [mlir][tosa] Separate tosa.transpose_conv decomposition and add stride support
Transpose convolution decomposition is now performed in a separate pass. This
allows padding / constant propagation to be performed at the TOSA level. It
also adds support for striding when there is no dilation.
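As a rough sketch of the unit-stride case (shapes taken from the new test
file added below; types and the out_shape attribute are elided here with
"..."), running mlir-opt --tosa-decompose-transpose-conv rewrites

  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2)
         {dilation = [1, 1], out_pad = [0, 0], stride = [1, 1], ...}
         : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>)
         -> tensor<2x18x19x5xf32>

into a kernel reversal feeding a regular convolution, roughly:

  %r1 = "tosa.reverse"(%arg1) {axis = 1 : i64} : ...
  %r2 = "tosa.reverse"(%r1) {axis = 2 : i64} : ...
  %0  = "tosa.conv2d"(%arg0, %r2, %arg2)
          {dilation = [1, 1], pad = [2, 2, 5, 5], stride = [1, 1]} : ...

The strided case additionally pads and reshapes the weight so that striding
can be expressed as extra output channels of a unit-stride convolution.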
Differential Revision: https://reviews.llvm.org/D114409
Added:
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
Modified:
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index b00b161aef156..278402eb93b01 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -19,6 +19,7 @@
namespace mlir {
namespace tosa {
+std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
std::unique_ptr<Pass> createTosaInferShapesPass();
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index dfa7b1f8582e3..7d6af621675b8 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -15,6 +15,21 @@
include "mlir/Pass/PassBase.td"
+def TosaDecomposeTransposeConv : FunctionPass<"tosa-decompose-transpose-conv"> {
+ let summary = "Decompose transpose convolutions into standard convolutions.";
+ let description = [{
+ Pass that uses shape manipulation and convolution operations to transform
+ a transpose convolution into a regular convolution.
+ }];
+
+ let constructor = "createTosaDecomposeTransposeConvPass()";
+ let dependentDialects = [
+ "StandardOpsDialect",
+ "tensor::TensorDialect",
+ "tosa::TosaDialect",
+ ];
+}
+
def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> {
let summary = "Propagate shapes across TOSA operations";
let description = [{
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f4470d20fca4c..77cf563abe1a1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1384,77 +1384,6 @@ class DepthwiseConvConverter
}
};
-class TransposeConvConverter
- : public OpConversionPattern<tosa::TransposeConv2DOp> {
-public:
- using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
- LogicalResult
- matchAndRewrite(tosa::TransposeConv2DOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- Location loc = op->getLoc();
- Value input = op->getOperand(0);
- Value weight = op->getOperand(1);
- Value bias = op->getOperand(2);
-
- ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType weightTy = weight.getType().cast<ShapedType>();
- ShapedType biasTy = bias.getType().cast<ShapedType>();
- ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
-
- llvm::SmallVector<int64_t> pad;
- llvm::SmallVector<int64_t> stride;
- llvm::SmallVector<int64_t> dilation;
-
- getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
- getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
- getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
-
- // If striding is all 1 we can modify padding and reverse the kernel along
- // the x/y direction to make it a regular convolution. This is much simpler
- // then handling striding....
- if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) {
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
- return failure();
-
- int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
- int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
- int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
- int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
-
- llvm::SmallVector<int64_t> convPad(4, 0);
- convPad[0] = kernelHeight - 1 - pad[0];
- convPad[2] = kernelWidth - 1 - pad[1];
- convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
- convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
-
- auto reverse1 = rewriter.create<tosa::ReverseOp>(
- loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
- auto reverse2 = rewriter.create<tosa::ReverseOp>(
- loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
-
- Value conv2d;
- if (op.quantization_info().hasValue()) {
- conv2d = rewriter.create<tosa::Conv2DOp>(
- loc, resultTy, input, reverse2, bias,
- rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
- rewriter.getI64ArrayAttr(dilation),
- op.quantization_info().getValue());
- } else {
- conv2d = rewriter.create<tosa::Conv2DOp>(
- loc, resultTy, input, reverse2, bias,
- rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
- rewriter.getI64ArrayAttr(dilation));
- }
-
- rewriter.replaceOp(op, conv2d);
- return success();
- }
-
- return failure();
- }
-};
-
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
@@ -3188,7 +3117,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
ConcatConverter,
ConvConverter,
DepthwiseConvConverter,
- TransposeConvConverter,
GatherConverter,
PadConverter,
ReshapeConverterCollapse,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 11aab5828cfd0..335fdfadcab4c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -50,6 +50,7 @@ struct TosaToLinalg : public TosaToLinalgBase<TosaToLinalg> {
target.addLegalOp<tosa::IfOp>();
target.addLegalOp<tosa::ConstOp>();
target.addLegalOp<tosa::WhileOp>();
+ target.addLegalOp<tosa::SliceOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index f466b1ab85389..b5e90bbeecc59 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
+ TosaDecomposeTransposeConv.cpp
TosaInferShapes.cpp
TosaMakeBroadcastable.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
new file mode 100644
index 0000000000000..c1fcca2d27e57
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -0,0 +1,390 @@
+//===- TosaDecomposeTransposeConv.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Decompose TOSA TransposeConv2D operations into regular convolutions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+template <typename T>
+static void getValuesFromIntArrayAttribute(ArrayAttr attr,
+ SmallVector<T> &arrayValues) {
+ for (Attribute val : attr.getValue()) {
+ arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+ }
+}
+
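+ // Helper that creates a TOSA op and runs shape inference on it: when the op
+ // implements InferShapedTypeOpInterface, the inferred shape is joined with
+ // the requested result type and the result type is refined in place.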
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
+ Args &&...args) {
+ auto op = rewriter.create<TosaOp>(loc, result_ty, args...);
+
+ InferShapedTypeOpInterface shapeInterface =
+ dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
+ if (!shapeInterface)
+ return op;
+
+ SmallVector<ShapedTypeComponents> returnedShapes;
+ if (shapeInterface
+ .inferReturnTypeComponents(op.getContext(), op.getLoc(),
+ op->getOperands(), op->getAttrDictionary(),
+ op->getRegions(), returnedShapes)
+ .failed())
+ return op;
+
+ // We need to use the element type of the existing result type to generate
+ // the new result shaped type. This is because rescale can include a cast to
+ // different bit-width types and does not have a TypeAttr to define the
+ // target type.
+ auto result = op->getResult(0);
+ auto predictedShape = returnedShapes[0];
+ auto currentKnowledge =
+ mlir::tosa::ValueKnowledge::getKnowledgeFromType(result_ty);
+
+ // Compute the knowledge based on the inferred type.
+ auto inferredKnowledge =
+ mlir::tosa::ValueKnowledge::getPessimisticValueState();
+ inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
+ inferredKnowledge.hasRank = predictedShape.hasRank();
+ if (predictedShape.hasRank()) {
+ for (auto dim : predictedShape.getDims()) {
+ inferredKnowledge.sizes.push_back(dim);
+ }
+ }
+
+ // Compute the new type based on the joined version.
+ auto newKnowledge =
+ mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+ auto new_ty = newKnowledge.getType();
+ result.setType(new_ty);
+ return op;
+}
+
+class TransposeConvDilatedConverter
+ : public OpRewritePattern<tosa::TransposeConv2DOp> {
+public:
+ using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op->getLoc();
+ Value input = op->getOperand(0);
+ Value weight = op->getOperand(1);
+ Value bias = op->getOperand(2);
+
+ ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType weightTy = weight.getType().cast<ShapedType>();
+ ShapedType biasTy = bias.getType().cast<ShapedType>();
+ ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+
+ llvm::SmallVector<int64_t> pad;
+ llvm::SmallVector<int64_t> stride;
+ llvm::SmallVector<int64_t> dilation;
+
+ getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
+ getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+ getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
+
+ // If striding is all 1 we can modify padding and reverse the kernel along
+ // the x/y direction to make it a regular convolution. This is much simpler
+ // than handling striding.
+ if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
+ return failure();
+
+ if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+ !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ return failure();
+
+ int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
+ int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
+ int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
+ int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
+
+ llvm::SmallVector<int64_t> convPad(4, 0);
+ convPad[0] = kernelHeight - 1 - pad[0];
+ convPad[2] = kernelWidth - 1 - pad[1];
+ convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
+ convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
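+ // For example, with the shapes used in the unit-stride test (input
+ // 2x16x14x3, weight 5x3x6x3, result 2x18x19x5, dilation [1, 1],
+ // out_pad [0, 0]) this computes convPad = [2, 2, 5, 5].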
+
+ auto reverse1 = rewriter.create<tosa::ReverseOp>(
+ loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
+ auto reverse2 = rewriter.create<tosa::ReverseOp>(
+ loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
+
+ Value conv2d;
+ if (op.quantization_info().hasValue()) {
+ conv2d = rewriter.create<tosa::Conv2DOp>(
+ loc, resultTy, input, reverse2, bias,
+ rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+ rewriter.getI64ArrayAttr(dilation),
+ op.quantization_info().getValue());
+ } else {
+ conv2d = rewriter.create<tosa::Conv2DOp>(
+ loc, resultTy, input, reverse2, bias,
+ rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+ rewriter.getI64ArrayAttr(dilation));
+ }
+
+ rewriter.replaceOp(op, conv2d);
+ return success();
+ }
+};
+
+class TransposeConvStridedConverter
+ : public OpRewritePattern<tosa::TransposeConv2DOp> {
+public:
+ using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op->getLoc();
+ Value input = op->getOperand(0);
+ Value weight = op->getOperand(1);
+ Value bias = op->getOperand(2);
+
+ ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType weightTy = weight.getType().cast<ShapedType>();
+ ShapedType biasTy = bias.getType().cast<ShapedType>();
+ ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+
+ Type inputETy = inputTy.getElementType();
+ Type weightETy = weightTy.getElementType();
+ Type biasETy = biasTy.getElementType();
+ Type resultETy = resultTy.getElementType();
+
+ llvm::SmallVector<int64_t> pad;
+ llvm::SmallVector<int64_t> stride;
+ llvm::SmallVector<int64_t> dilation;
+
+ getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
+ getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+ getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
+
+ // This pattern handles strided transpose convolutions only when there is
+ // no dilation; bail out if any dilation is present.
+ if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
+ return failure();
+
+ // If all strides are 1 this pattern is not needed; the unit-stride
+ // pattern above already handles that case.
+ if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
+ return failure();
+
+ if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+ !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ return failure();
+
+ int64_t batch = inputTy.getDimSize(0);
+
+ int64_t outputChannels = weightTy.getDimSize(0);
+ int64_t weightHeight = weightTy.getDimSize(1);
+ int64_t weightWidth = weightTy.getDimSize(2);
+ int64_t inputChannels = weightTy.getDimSize(3);
+
+ // Pad the weight so that its spatial dimensions are a multiple of the
+ // stride.
+ llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+ weightPadding[3] =
+ weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
+ weightPadding[5] =
+ weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
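+ // e.g. in the strided test (weight 5x3x5x3, stride [2, 3]) this pads the
+ // kernel by 1 in height and 1 in width, giving a 5x4x6x3 weight whose
+ // spatial dimensions divide evenly by the stride.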
+ DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
+ RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
+ Value weightPaddingVal = CreateOpAndInfer<tosa::ConstOp>(
+ rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+
+ if (op.quantization_info().hasValue()) {
+ auto quantInfo = op.quantization_info().getValue();
+ weight = CreateOpAndInfer<tosa::PadOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ weightPaddingVal, nullptr,
+ PadOpQuantizationAttr::get(quantInfo.weight_zp(),
+ rewriter.getContext()));
+
+ } else {
+ weight = CreateOpAndInfer<tosa::PadOp>(rewriter, loc,
+ UnrankedTensorType::get(weightETy),
+ weight, weightPaddingVal);
+ }
+
+ weightTy = weight.getType().cast<ShapedType>();
+ weightHeight = weightTy.getDimSize(1);
+ weightWidth = weightTy.getDimSize(2);
+
+ // Split out the width / height by the stride dimensions.
+ llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
+ outputChannels, weightHeight / stride[0],
+ stride[0], weightWidth / stride[1],
+ stride[1], inputChannels};
+ weight = CreateOpAndInfer<tosa::ReshapeOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ rewriter.getI64ArrayAttr(weightReshapeDims0));
+
+ // Transpose the factored-out stride to the output channels.
+ Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
+ loc, RankedTensorType::get({6}, rewriter.getI32Type()),
+ rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
+
+ weight = CreateOpAndInfer<tosa::TransposeOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ transposeWeightVal);
+
+ // Collapse the strides and output channels into a single dimension.
+ llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
+ outputChannels * stride[0] * stride[1], weightHeight / stride[0],
+ weightWidth / stride[1], inputChannels};
+ weight = CreateOpAndInfer<tosa::ReshapeOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ rewriter.getI64ArrayAttr(weightReshapeDims1));
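+ // Continuing the strided test example: the padded 5x4x6x3 weight is
+ // reshaped to [5, 2, 2, 2, 3, 3], transposed, and collapsed to
+ // [30, 2, 2, 3], i.e. one 2x2 kernel per (output channel, stride offset)
+ // pair.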
+ ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
+
+ weight = CreateOpAndInfer<tosa::ReverseOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ rewriter.getI64IntegerAttr(1));
+ weight = CreateOpAndInfer<tosa::ReverseOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ rewriter.getI64IntegerAttr(2));
+
+ // Pad the input so the unit-stride convolution below can produce every
+ // output position, including those along the borders.
+ llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+ inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
+ inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
+ inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
+ inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
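+ // With the 2x2 restrided kernel from above, each spatial edge of the
+ // input is padded by 1, matching a "full" unit-stride convolution.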
+
+ DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
+ RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
+
+ Value inputPaddingVal = CreateOpAndInfer<tosa::ConstOp>(
+ rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+
+ if (op.quantization_info().hasValue()) {
+ auto quantInfo = op.quantization_info().getValue();
+ input = CreateOpAndInfer<tosa::PadOp>(
+ rewriter, loc, UnrankedTensorType::get(inputETy), input,
+ inputPaddingVal, nullptr,
+ PadOpQuantizationAttr::get(quantInfo.input_zp(),
+ rewriter.getContext()));
+ } else {
+ input = CreateOpAndInfer<tosa::PadOp>(rewriter, loc,
+ UnrankedTensorType::get(inputETy),
+ input, inputPaddingVal);
+ }
+
+ // Use a zero bias for the intermediate convolution; the real bias is
+ // added once the result has been restored to its final shape.
+ auto zeroBias = rewriter.create<tosa::ConstOp>(
+ loc,
+ RankedTensorType::get({outputChannels * stride[0] * stride[1]},
+ biasETy),
+ DenseElementsAttr::get(
+ RankedTensorType::get({outputChannels * stride[0] * stride[1]},
+ biasETy),
+ rewriter.getZeroAttr(biasETy)));
+
+ // Perform the convolution using the zero bias.
+ Value conv2d;
+ if (op.quantization_info().hasValue()) {
+ conv2d = CreateOpAndInfer<tosa::Conv2DOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), input,
+ weight, zeroBias,
+ /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
+ /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
+ /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
+ op.quantization_info().getValue())
+ .getResult();
+ } else {
+ conv2d = CreateOpAndInfer<tosa::Conv2DOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), input,
+ weight, zeroBias,
+ /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
+ /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
+ /*dilation=*/rewriter.getI64ArrayAttr({1, 1}))
+ .getResult();
+ }
+
+ // Factor the resulting width / height.
+ ShapedType convTy = conv2d.getType().cast<ShapedType>();
+ Type convETy = convTy.getElementType();
+
+ int64_t convHeight = convTy.getDimSize(1);
+ int64_t convWidth = convTy.getDimSize(2);
+
+ // Factor striding out of the convolution result.
+ llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
+ batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
+ conv2d = CreateOpAndInfer<tosa::ReshapeOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+ rewriter.getI64ArrayAttr(convReshapeDims0));
+
+ // Transpose the factored-out stride to the output channels.
+ Value transposeConvVal = rewriter.create<tosa::ConstOp>(
+ loc, RankedTensorType::get({6}, rewriter.getI32Type()),
+ rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
+
+ conv2d = CreateOpAndInfer<tosa::TransposeOp>(
+ rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
+ transposeConvVal);
+
+ // Fuse striding behavior back into width / height.
+ llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
+ batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
+ conv2d = CreateOpAndInfer<tosa::ReshapeOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+ rewriter.getI64ArrayAttr(convReshapeDims1));
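+ // In the strided test this takes the 2x18x16x30 convolution result through
+ // [2, 18, 16, 2, 3, 5] -> [2, 18, 2, 16, 3, 5] -> [2, 36, 48, 5]; the
+ // slice below then trims it to the requested 2x35x47x5 output.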
+
+ // Slice out the final result.
+ llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0};
+ llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(),
+ resultTy.getShape().end());
+ sliceBegin[1] = pad[0];
+ sliceBegin[2] = pad[1];
+
+ auto slice = CreateOpAndInfer<tosa::SliceOp>(
+ rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+ rewriter.getI64ArrayAttr(sliceBegin),
+ rewriter.getI64ArrayAttr(sliceSize))
+ .getResult();
+
+ auto addBias =
+ CreateOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);
+
+ rewriter.replaceOp(op, addBias.getResult());
+
+ return success();
+ }
+};
+
+/// Pass that decomposes tosa.transpose_conv2d operations into equivalent
+/// sequences of TOSA convolution, reverse, reshape, pad, and slice operations.
+struct TosaDecomposeTransposeConv
+ : public TosaDecomposeTransposeConvBase<TosaDecomposeTransposeConv> {
+public:
+ void runOnFunction() override {
+ auto func = getFunction();
+ RewritePatternSet patterns(func.getContext());
+ patterns
+ .insert<TransposeConvDilatedConverter, TransposeConvStridedConverter>(
+ func.getContext());
+ (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+ }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaDecomposeTransposeConvPass() {
+ return std::make_unique<TosaDecomposeTransposeConv>();
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 2e25ad975a09e..1cf88f9bc9709 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1719,27 +1719,6 @@ func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : tenso
return
}
-// -----
-
-// CHECK-LABEL: @transpose_conv
-func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
- // CHECK: linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]
- // CHECK: linalg.conv_2d_nhwc_hwcf
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 14, 14, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x14x14x4xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_conv_dilated
-func @transpose_conv_dilated(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
- // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 4, 4, 0] high[0, 4, 4, 0]
- // CHECK: linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<3x3x2x4xf32>)
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 2], out_pad = [0, 0], out_shape = [1, 16, 16, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x16x16x4xf32>
- return
-}
-
-
// -----
// CHECK-LABEL: @resize_nearest
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
new file mode 100644
index 0000000000000..627622ba796e3
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt --split-input-file --tosa-decompose-transpose-conv %s | FileCheck %s
+
+// CHECK-LABEL: @transpose_conv2d
+func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+ // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+ // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+ // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], stride = [1, 1]}
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
+ %1 = tensor.cast %0 : tensor<2x18x19x5xf32> to tensor<2x?x?x5xf32>
+ return %1 : tensor<2x?x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_quantized
+func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) {
+ // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+ // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+ // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]}
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
+ return %0 : tensor<2x18x19x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_dilated
+func @transpose_conv2d_dilated(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+ // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+ // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+ // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [2, 3], pad = [4, 4, 15, 15], stride = [1, 1]}
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 3], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x20x29x5xf32>
+ %1 = tensor.cast %0 : tensor<2x20x29x5xf32> to tensor<2x?x?x5xf32>
+ return %1 : tensor<2x?x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_strided
+func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+ // Manipulate the weight matrix to handle striding.
+ // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
+ // CHECK-DAG: %[[PADW:.+]] = "tosa.pad"(%arg1, %[[PADV]])
+ // CHECK-DAG: %[[RESW1:.+]] = "tosa.reshape"(%[[PADW]]) {new_shape = [5, 2, 2, 2, 3, 3]}
+ // CHECK-DAG: %[[TRANS:.+]] = "tosa.transpose"(%[[RESW1]], %[[TRANSV]])
+ // CHECK-DAG: %[[RESW2:.+]] = "tosa.reshape"(%[[TRANS]]) {new_shape = [30, 2, 2, 3]}
+ // CHECK-DAG: %[[REV1:.+]] = "tosa.reverse"(%[[RESW2]]) {axis = 1 : i64}
+ // CHECK-DAG: %[[NEWWEIGHT:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+
+ // Pad out the input matrix to handle the transpose conv.
+ // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+ // CHECK-DAG: %[[NEWINPUT:.+]] = "tosa.pad"(%arg0, %[[PAD]])
+
+ // Manipulate the final shape.
+ // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<30xf32>}
+ // CHECK-DAG: %[[CONV:.+]] = "tosa.conv2d"(%[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]]) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]}
+ // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, 18, 16, 2, 3, 5]}
+ // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
+ // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]}
+ // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]}
+ // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
+ %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
+ return %1 : tensor<2x?x?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_strided_quantized
+func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
+ // Manipulate the weight matrix to handle striding.
+ // CHECK-DAG: %[[PADV:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
+ // CHECK-DAG: %[[PADW:.+]] = "tosa.pad"(%arg1, %[[PADV]]) {quantization_info = {input_zp = 42 : i32}}
+ // CHECK-DAG: %[[RESW1:.+]] = "tosa.reshape"(%[[PADW]]) {new_shape = [5, 2, 2, 2, 3, 3]}
+ // CHECK-DAG: %[[TRANS:.+]] = "tosa.transpose"(%[[RESW1]], %[[TRANSV]])
+ // CHECK-DAG: %[[RESW2:.+]] = "tosa.reshape"(%[[TRANS]]) {new_shape = [30, 2, 2, 3]}
+ // CHECK-DAG: %[[REV1:.+]] = "tosa.reverse"(%[[RESW2]]) {axis = 1 : i64}
+ // CHECK-DAG: %[[NEWWEIGHT:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+
+ // Pad out the input matrix to handle the transpose conv.
+ // CHECK-DAG: %[[PAD:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+ // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+ // CHECK-DAG: %[[NEWINPUT:.+]] = "tosa.pad"(%arg0, %[[PAD]]) {quantization_info = {input_zp = -22 : i32}}
+
+ // Manipulate the final shape.
+ // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() {value = dense<0> : tensor<30xi32>}
+ // CHECK-DAG: %[[CONV:.+]] = "tosa.conv2d"(%[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]]) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]}
+ // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, 18, 16, 2, 3, 5]}
+ // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
+ // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]}
+ // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]}
+ // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
+ return %0 : tensor<2x35x47x5xi32>
+}