[Mlir-commits] [mlir] 54eec7c - [mlir][tosa] Separate tosa.transpose_conv decomposition and add stride support

Rob Suderman llvmlistbot at llvm.org
Tue Nov 23 12:22:30 PST 2021


Author: Rob Suderman
Date: 2021-11-23T12:16:44-08:00
New Revision: 54eec7cafc396f3d1444aacf4f1ed71fceb4e503

URL: https://github.com/llvm/llvm-project/commit/54eec7cafc396f3d1444aacf4f1ed71fceb4e503
DIFF: https://github.com/llvm/llvm-project/commit/54eec7cafc396f3d1444aacf4f1ed71fceb4e503.diff

LOG: [mlir][tosa] Separate tosa.transpose_conv decomposition and add stride support

Transpose convolution decomposition is now performed in a separate pass. This
allows padding / constant propagation to be performed at the TOSA level. It
also adds support for striding when there is no dilation.

Differential Revision: https://reviews.llvm.org/D114409

Added: 
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
    mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index b00b161aef156..278402eb93b01 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -19,6 +19,7 @@
 namespace mlir {
 namespace tosa {
 
+std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
 std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index dfa7b1f8582e3..7d6af621675b8 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -15,6 +15,21 @@
 
 include "mlir/Pass/PassBase.td"
 
+def TosaDecomposeTransposeConv : FunctionPass<"tosa-decompose-transpose-conv"> {
+  let summary = "Decompose transpose convolutions into standard convolutions";
+  let description = [{
+    Pass that uses shape manipulation and convolution operations to transform
+    a transpose convolution into a regular convolution. Unit-stride transpose
+    convolutions (with any dilation) become a regular convolution over the
+    reversed kernel with adjusted padding. Strided transpose convolutions
+    without dilation are rewritten by padding, reshaping and transposing the
+    kernel, convolving, and then reshaping, transposing and slicing the
+    result before adding the bias. Only statically shaped ops are rewritten.
+  }];
+
+  let constructor = "createTosaDecomposeTransposeConvPass()";
+  let dependentDialects = [
+    "StandardOpsDialect",
+    "tensor::TensorDialect",
+    "tosa::TosaDialect",
+  ];
+}
+
 def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> {
   let summary = "Propagate shapes across TOSA operations";
   let description = [{

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f4470d20fca4c..77cf563abe1a1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1384,77 +1384,6 @@ class DepthwiseConvConverter
   }
 };
 
-class TransposeConvConverter
-    : public OpConversionPattern<tosa::TransposeConv2DOp> {
-public:
-  using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tosa::TransposeConv2DOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    Location loc = op->getLoc();
-    Value input = op->getOperand(0);
-    Value weight = op->getOperand(1);
-    Value bias = op->getOperand(2);
-
-    ShapedType inputTy = input.getType().cast<ShapedType>();
-    ShapedType weightTy = weight.getType().cast<ShapedType>();
-    ShapedType biasTy = bias.getType().cast<ShapedType>();
-    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
-
-    llvm::SmallVector<int64_t> pad;
-    llvm::SmallVector<int64_t> stride;
-    llvm::SmallVector<int64_t> dilation;
-
-    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
-    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
-    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
-
-    // If striding is all 1 we can modify padding and reverse the kernel along
-    // the x/y direction to make it a regular convolution. This is much simpler
-    // then handling striding....
-    if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) {
-      if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-          !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
-        return failure();
-
-      int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
-      int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
-      int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
-      int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
-
-      llvm::SmallVector<int64_t> convPad(4, 0);
-      convPad[0] = kernelHeight - 1 - pad[0];
-      convPad[2] = kernelWidth - 1 - pad[1];
-      convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
-      convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
-
-      auto reverse1 = rewriter.create<tosa::ReverseOp>(
-          loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
-      auto reverse2 = rewriter.create<tosa::ReverseOp>(
-          loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
-
-      Value conv2d;
-      if (op.quantization_info().hasValue()) {
-        conv2d = rewriter.create<tosa::Conv2DOp>(
-            loc, resultTy, input, reverse2, bias,
-            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
-            rewriter.getI64ArrayAttr(dilation),
-            op.quantization_info().getValue());
-      } else {
-        conv2d = rewriter.create<tosa::Conv2DOp>(
-            loc, resultTy, input, reverse2, bias,
-            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
-            rewriter.getI64ArrayAttr(dilation));
-      }
-
-      rewriter.replaceOp(op, conv2d);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
 public:
   using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
@@ -3188,7 +3117,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       ConcatConverter,
       ConvConverter,
       DepthwiseConvConverter,
-      TransposeConvConverter,
       GatherConverter,
       PadConverter,
       ReshapeConverterCollapse,

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 11aab5828cfd0..335fdfadcab4c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -50,6 +50,7 @@ struct TosaToLinalg : public TosaToLinalgBase<TosaToLinalg> {
     target.addLegalOp<tosa::IfOp>();
     target.addLegalOp<tosa::ConstOp>();
     target.addLegalOp<tosa::WhileOp>();
+    target.addLegalOp<tosa::SliceOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index f466b1ab85389..b5e90bbeecc59 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRTosaTransforms
+  TosaDecomposeTransposeConv.cpp
   TosaInferShapes.cpp
   TosaMakeBroadcastable.cpp
 

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
new file mode 100644
index 0000000000000..c1fcca2d27e57
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -0,0 +1,390 @@
+//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Decompose tosa.transpose_conv2d into tosa.conv2d plus supporting shape
+// manipulation ops (reverse, pad, reshape, transpose, slice, add).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR//TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// Appends the sign-extended integer value of each element of `attr` to
+/// `arrayValues`, keeping the attribute's element order. Each element is
+/// `cast` to IntegerAttr, so a non-integer element is a programming error.
+template <typename T>
+static void getValuesFromIntArrayAttribute(ArrayAttr attr,
+                                           SmallVector<T> &arrayValues) {
+  auto elements = attr.getValue();
+  for (Attribute element : elements)
+    arrayValues.push_back(
+        element.cast<IntegerAttr>().getValue().getSExtValue());
+}
+
+/// Creates a `TosaOp` with result type `result_ty`. If the op implements
+/// InferShapedTypeOpInterface, the type of its first result is then refined
+/// by joining `result_ty` with the inferred result shape. The element type
+/// of the refined type is always taken from `result_ty`. If the op does not
+/// implement the interface, or inference fails or reports no shapes, the op
+/// is returned with `result_ty` unchanged.
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
+                        Args &&...args) {
+  auto op =
+      rewriter.create<TosaOp>(loc, result_ty, std::forward<Args>(args)...);
+
+  InferShapedTypeOpInterface shapeInterface =
+      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
+  if (!shapeInterface)
+    return op;
+
+  SmallVector<ShapedTypeComponents> returnedShapes;
+  if (shapeInterface
+          .inferReturnTypeComponents(op.getContext(), op.getLoc(),
+                                     op->getOperands(), op->getAttrDictionary(),
+                                     op->getRegions(), returnedShapes)
+          .failed())
+    return op;
+
+  // Defensive: never index returnedShapes if inference reported success
+  // without producing a shape.
+  if (returnedShapes.empty())
+    return op;
+
+  // We need to use the element type of the existing result type to generate
+  // the new result shaped type. This is because rescale can include a cast to
+  // different bit-width types and does not have a TypeAttr to define the
+  // target type.
+  auto result = op->getResult(0);
+  const ShapedTypeComponents &predictedShape = returnedShapes[0];
+  auto currentKnowledge =
+      mlir::tosa::ValueKnowledge::getKnowledgeFromType(result_ty);
+
+  // Compute the knowledge based on the inferred type.
+  auto inferredKnowledge =
+      mlir::tosa::ValueKnowledge::getPessimisticValueState();
+  inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
+  inferredKnowledge.hasRank = predictedShape.hasRank();
+  if (predictedShape.hasRank()) {
+    for (auto dim : predictedShape.getDims()) {
+      inferredKnowledge.sizes.push_back(dim);
+    }
+  }
+
+  // Compute the new type based on the joined version and update the result
+  // type in place.
+  auto newKnowledge =
+      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+  auto newTy = newKnowledge.getType();
+  result.setType(newTy);
+  return op;
+}
+
+/// Rewrites a tosa.transpose_conv2d whose strides are all 1 (with any
+/// dilation) into a tosa.conv2d. The conv2d runs over the kernel reversed
+/// along its height and width axes (1 and 2). Its padding is chosen so the
+/// regular convolution produces the transpose convolution's result shape.
+/// Requires static input, weight, bias and result shapes.
+class TransposeConvDilatedConverter
+    : public OpRewritePattern<tosa::TransposeConv2DOp> {
+public:
+  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    Value input = op->getOperand(0);
+    Value weight = op->getOperand(1);
+    Value bias = op->getOperand(2);
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType weightTy = weight.getType().cast<ShapedType>();
+    ShapedType biasTy = bias.getType().cast<ShapedType>();
+    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+
+    llvm::SmallVector<int64_t> pad;
+    llvm::SmallVector<int64_t> stride;
+    llvm::SmallVector<int64_t> dilation;
+
+    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
+    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
+
+    // If striding is all 1 we can modify padding and reverse the kernel along
+    // the x/y direction to make it a regular convolution. This is much simpler
+    // than handling striding, which is left to TransposeConvStridedConverter.
+    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
+      return failure();
+
+    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    // Extent of the dilated kernel, and the input extent a unit-stride
+    // convolution needs in order to produce the requested output extent.
+    int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
+    int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
+    int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
+    int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
+
+    // convPad holds {height before, height after, width before, width after}.
+    // Each leading pad is (dilated kernel extent - 1 - out_pad). Each
+    // trailing pad makes up whatever input extent is still required.
+    llvm::SmallVector<int64_t> convPad(4, 0);
+    convPad[0] = kernelHeight - 1 - pad[0];
+    convPad[2] = kernelWidth - 1 - pad[1];
+    convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
+    convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
+
+    // TOSA convolutions require non-negative padding. If out_pad exceeds the
+    // dilated kernel, or the result shape is too small for this rewrite,
+    // leave the op untouched instead of emitting invalid IR.
+    if (llvm::any_of(convPad, [](int64_t v) { return v < 0; }))
+      return failure();
+
+    // Reverse the kernel along height (axis 1) and then width (axis 2).
+    auto reverse1 = rewriter.create<tosa::ReverseOp>(
+        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
+    auto reverse2 = rewriter.create<tosa::ReverseOp>(
+        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
+
+    // Forward the quantization info, when present, to the regular conv.
+    Value conv2d;
+    if (op.quantization_info().hasValue()) {
+      conv2d = rewriter.create<tosa::Conv2DOp>(
+          loc, resultTy, input, reverse2, bias,
+          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+          rewriter.getI64ArrayAttr(dilation),
+          op.quantization_info().getValue());
+    } else {
+      conv2d = rewriter.create<tosa::Conv2DOp>(
+          loc, resultTy, input, reverse2, bias,
+          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+          rewriter.getI64ArrayAttr(dilation));
+    }
+
+    rewriter.replaceOp(op, conv2d);
+    return success();
+  }
+};
+
+/// Rewrites a tosa.transpose_conv2d with unit dilation and at least one
+/// non-unit stride into a unit-stride tosa.conv2d:
+///   1. Pad the kernel's height / width up to multiples of the stride. Then
+///      reshape, transpose and reshape it so that each stride phase becomes
+///      its own group of output channels (outputChannels * stride[0] *
+///      stride[1] in total), and reverse it along axes 1 and 2.
+///   2. Pad the input by (restrided kernel extent - 1) on both spatial sides
+///      and convolve with a zero bias.
+///   3. Reshape, transpose and reshape the convolution result to interleave
+///      the stride phases back into height / width. Slice out the result
+///      starting at out_pad, then add the real bias.
+/// Requires static input, weight, bias and result shapes.
+class TransposeConvStridedConverter
+    : public OpRewritePattern<tosa::TransposeConv2DOp> {
+public:
+  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    Value input = op->getOperand(0);
+    Value weight = op->getOperand(1);
+    Value bias = op->getOperand(2);
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType weightTy = weight.getType().cast<ShapedType>();
+    ShapedType biasTy = bias.getType().cast<ShapedType>();
+    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+
+    Type inputETy = inputTy.getElementType();
+    Type weightETy = weightTy.getElementType();
+    Type biasETy = biasTy.getElementType();
+    Type resultETy = resultTy.getElementType();
+
+    llvm::SmallVector<int64_t> pad;
+    llvm::SmallVector<int64_t> stride;
+    llvm::SmallVector<int64_t> dilation;
+
+    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
+    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
+
+    // This decomposition does not handle dilation. Dilated transpose
+    // convolutions are only decomposed when unit-strided, by
+    // TransposeConvDilatedConverter.
+    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
+      return failure();
+
+    // If strides are all 1 we don't need to use this one; the simpler
+    // unit-stride decomposition applies instead.
+    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
+      return failure();
+
+    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    int64_t batch = inputTy.getDimSize(0);
+
+    int64_t outputChannels = weightTy.getDimSize(0);
+    int64_t weightHeight = weightTy.getDimSize(1);
+    int64_t weightWidth = weightTy.getDimSize(2);
+    int64_t inputChannels = weightTy.getDimSize(3);
+
+    // Pad the weight so its height / width are multiples of the stride. The
+    // padding tensor holds one {before, after} pair per dimension. Index 3 is
+    // therefore the trailing pad of the height (dim 1), and index 5 is the
+    // trailing pad of the width (dim 2).
+    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    weightPadding[3] =
+        weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
+    weightPadding[5] =
+        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
+    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
+    Value weightPaddingVal = CreateOpAndInfer<tosa::ConstOp>(
+        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+
+    // When quantized, the padded kernel entries use the weight zero point.
+    if (op.quantization_info().hasValue()) {
+      auto quantInfo = op.quantization_info().getValue();
+      weight = CreateOpAndInfer<tosa::PadOp>(
+          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+          weightPaddingVal, nullptr,
+          PadOpQuantizationAttr::get(quantInfo.weight_zp(),
+                                     rewriter.getContext()));
+
+    } else {
+      weight = CreateOpAndInfer<tosa::PadOp>(rewriter, loc,
+                                             UnrankedTensorType::get(weightETy),
+                                             weight, weightPaddingVal);
+    }
+
+    weightTy = weight.getType().cast<ShapedType>();
+    weightHeight = weightTy.getDimSize(1);
+    weightWidth = weightTy.getDimSize(2);
+
+    // Split out the width / height by the stride dimensions.
+    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
+        outputChannels, weightHeight / stride[0],
+        stride[0],      weightWidth / stride[1],
+        stride[1],      inputChannels};
+    weight = CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        rewriter.getI64ArrayAttr(weightReshapeDims0));
+
+    // Transpose the factored-out stride to the output channels.
+    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
+        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
+        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
+
+    weight = CreateOpAndInfer<tosa::TransposeOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        transposeWeightVal);
+
+    // Collapse the strides and output channels into a single dimension.
+    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
+        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
+        weightWidth / stride[1], inputChannels};
+    weight = CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        rewriter.getI64ArrayAttr(weightReshapeDims1));
+    ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
+
+    // Reverse the restrided kernel along height (axis 1) and width (axis 2).
+    weight = CreateOpAndInfer<tosa::ReverseOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        rewriter.getI64IntegerAttr(1));
+    weight = CreateOpAndInfer<tosa::ReverseOp>(
+        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+        rewriter.getI64IntegerAttr(2));
+
+    // We need to pad the input far enough that we can pull all values: the
+    // restrided kernel extent minus one, on both sides of height and width.
+    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
+    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
+    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
+    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
+
+    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
+
+    Value inputPaddingVal = CreateOpAndInfer<tosa::ConstOp>(
+        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+
+    // When quantized, the padded input entries use the input zero point.
+    if (op.quantization_info().hasValue()) {
+      auto quantInfo = op.quantization_info().getValue();
+      input = CreateOpAndInfer<tosa::PadOp>(
+          rewriter, loc, UnrankedTensorType::get(inputETy), input,
+          inputPaddingVal, nullptr,
+          PadOpQuantizationAttr::get(quantInfo.input_zp(),
+                                     rewriter.getContext()));
+    } else {
+      input = CreateOpAndInfer<tosa::PadOp>(rewriter, loc,
+                                            UnrankedTensorType::get(inputETy),
+                                            input, inputPaddingVal);
+    }
+
+    // The convolution produces outputChannels * stride[0] * stride[1]
+    // channels, so the real bias cannot be applied here. Convolve with a
+    // zero bias and add the real bias once the stride phases are folded back
+    // into height / width.
+    auto zeroBias = rewriter.create<tosa::ConstOp>(
+        loc,
+        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
+                              biasETy),
+        DenseElementsAttr::get(
+            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
+                                  biasETy),
+            rewriter.getZeroAttr(biasETy)));
+
+    // Perform the convolution using the zero bias.
+    Value conv2d;
+    if (op.quantization_info().hasValue()) {
+      conv2d = CreateOpAndInfer<tosa::Conv2DOp>(
+                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
+                   weight, zeroBias,
+                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
+                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
+                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
+                   op.quantization_info().getValue())
+                   .getResult();
+    } else {
+      conv2d = CreateOpAndInfer<tosa::Conv2DOp>(
+                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
+                   weight, zeroBias,
+                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
+                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
+                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}))
+                   .getResult();
+    }
+
+    // Factor the resulting width / height.
+    ShapedType convTy = conv2d.getType().cast<ShapedType>();
+    Type convETy = convTy.getElementType();
+
+    int64_t convHeight = convTy.getDimSize(1);
+    int64_t convWidth = convTy.getDimSize(2);
+
+    // Factor striding out of the convolution result.
+    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
+        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
+    conv2d = CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+        rewriter.getI64ArrayAttr(convReshapeDims0));
+
+    // Transpose the factored-out stride next to height / width.
+    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
+        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
+        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
+
+    conv2d = CreateOpAndInfer<tosa::TransposeOp>(
+        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
+        transposeConvVal);
+
+    // Fuse striding behavior back into width / height.
+    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
+        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
+    conv2d = CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+        rewriter.getI64ArrayAttr(convReshapeDims1));
+
+    // Slice out the final result. It starts at out_pad in height / width and
+    // has exactly the (static) result shape.
+    llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0};
+    sliceBegin[1] = pad[0];
+    sliceBegin[2] = pad[1];
+
+    auto slice = CreateOpAndInfer<tosa::SliceOp>(
+                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
+                     rewriter.getI64ArrayAttr(sliceBegin),
+                     rewriter.getI64ArrayAttr(resultTy.getShape()))
+                     .getResult();
+
+    // The sliced value now has outputChannels channels, so the real bias can
+    // be added.
+    auto addBias =
+        CreateOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);
+
+    rewriter.replaceOp(op, addBias.getResult());
+
+    return success();
+  }
+};
+
+/// Pass that decomposes tosa.transpose_conv2d ops into tosa.conv2d plus
+/// supporting shape-manipulation ops. It greedily applies the unit-stride
+/// (TransposeConvDilatedConverter) and strided
+/// (TransposeConvStridedConverter) rewrite patterns to the function.
+struct TosaDecomposeTransposeConv
+    : public TosaDecomposeTransposeConvBase<TosaDecomposeTransposeConv> {
+  void runOnFunction() override {
+    auto func = getFunction();
+    RewritePatternSet patterns(func.getContext());
+    patterns
+        .insert<TransposeConvDilatedConverter, TransposeConvStridedConverter>(
+            func.getContext());
+    // Best effort: the driver's result is ignored, and ops matching neither
+    // pattern are left as they are.
+    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+  }
+};
+} // end anonymous namespace
+
+/// Builds a new instance of the transpose-convolution decomposition pass.
+/// Passes.td names this function as the pass constructor.
+std::unique_ptr<Pass> mlir::tosa::createTosaDecomposeTransposeConvPass() {
+  auto pass = std::make_unique<TosaDecomposeTransposeConv>();
+  return pass;
+}

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 2e25ad975a09e..1cf88f9bc9709 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1719,27 +1719,6 @@ func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : tenso
   return
 }
 
-// -----
-
-// CHECK-LABEL: @transpose_conv
-func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
-  // CHECK: linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]
-  // CHECK: linalg.conv_2d_nhwc_hwcf
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 14, 14, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x14x14x4xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_conv_dilated
-func @transpose_conv_dilated(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
-  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 4, 4, 0] high[0, 4, 4, 0]
-  // CHECK: linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<3x3x2x4xf32>)
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 2], out_pad = [0, 0], out_shape = [1, 16, 16, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x16x16x4xf32>
-  return
-}
-
-
 // -----
 
 // CHECK-LABEL: @resize_nearest

diff  --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
new file mode 100644
index 0000000000000..627622ba796e3
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt --split-input-file --tosa-decompose-transpose-conv %s | FileCheck %s
+
+// Unit stride, unit dilation: lowered to a conv2d over the kernel reversed on
+// axes 1 and 2. Leading pads are kernel extent - 1 (3 - 1 = 2 and 6 - 1 = 5).
+// Trailing pads fill out the input needed for the 18x19 output.
+// CHECK-LABEL: @transpose_conv2d
+func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+  // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+  // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], stride = [1, 1]}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
+  %1 = tensor.cast %0 : tensor<2x18x19x5xf32> to tensor<2x?x?x5xf32>
+  return %1 : tensor<2x?x?x5xf32>
+}
+
+// -----
+
+// Quantized unit-stride case: quantization_info is forwarded unchanged to the
+// resulting conv2d.
+// CHECK-LABEL: @transpose_conv2d_quantized
+func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) {
+  // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+  // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
+  return %0 : tensor<2x18x19x5xi32>
+}
+
+// -----
+
+// Unit stride with dilation [2, 3]. The dilated kernel is (3 - 1) * 2 + 1 = 5
+// by (6 - 1) * 3 + 1 = 16, so the leading pads are 4 and 15.
+// CHECK-LABEL: @transpose_conv2d_dilated
+func @transpose_conv2d_dilated(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+  // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+  // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [2, 3], pad = [4, 4, 15, 15], stride = [1, 1]}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 3], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x20x29x5xf32>
+  %1 = tensor.cast %0 : tensor<2x20x29x5xf32> to tensor<2x?x?x5xf32>
+  return %1 : tensor<2x?x?x5xf32>
+}
+
+// -----
+
+// Stride [2, 3], no dilation. The 3x5 kernel is padded to 4x6 and split into
+// 2 * 3 stride phases of 2x2 (30 channels). The conv result is interleaved
+// back to 36x48 and sliced to the 35x47 output.
+// CHECK-LABEL: @transpose_conv2d_strided
+func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+  // Manipulate the weight matrix to handle striding.
+  // CHECK-DAG: %[[PADV:.+]]  = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[TRANSV:.+]]  = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
+  // CHECK-DAG: %[[PADW:.+]]  = "tosa.pad"(%arg1, %[[PADV]])
+  // CHECK-DAG: %[[RESW1:.+]]  = "tosa.reshape"(%[[PADW]]) {new_shape = [5, 2, 2, 2, 3, 3]}
+  // CHECK-DAG: %[[TRANS:.+]]  = "tosa.transpose"(%[[RESW1]], %[[TRANSV]])
+  // CHECK-DAG: %[[RESW2:.+]]  = "tosa.reshape"(%[[TRANS]]) {new_shape = [30, 2, 2, 3]}
+  // CHECK-DAG: %[[REV1:.+]]  = "tosa.reverse"(%[[RESW2]]) {axis = 1 : i64}
+  // CHECK-DAG: %[[NEWWEIGHT:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+
+  // Pad out the input matrix to handle the transpose conv.
+  // CHECK-DAG: %[[PAD:.+]]  = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[TRANS2:.+]]  = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+  // CHECK-DAG: %[[NEWINPUT:.+]] = "tosa.pad"(%arg0, %[[PAD]])
+
+  // Manipulate the final shape.
+  // CHECK-DAG: %[[BIAS:.+]]  = "tosa.const"() {value = dense<0.000000e+00> : tensor<30xf32>}
+  // CHECK-DAG: %[[CONV:.+]] = "tosa.conv2d"(%[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]]) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]}
+  // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, 18, 16, 2, 3, 5]}
+  // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
+  // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]}
+  // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]}
+  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
+  %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
+  return %1 : tensor<2x?x?x5xf32>
+}
+
+// -----
+
+// Quantized strided case. The weight pad uses the weight zero point (42) and
+// the input pad uses the input zero point (-22). The conv2d keeps the
+// original quantization_info.
+// CHECK-LABEL: @transpose_conv2d_strided_quantized
+func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
+  // Manipulate the weight matrix to handle striding.
+  // CHECK-DAG: %[[PADV:.+]]  = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[TRANSV:.+]]  = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
+  // CHECK-DAG: %[[PADW:.+]]  = "tosa.pad"(%arg1, %[[PADV]]) {quantization_info = {input_zp = 42 : i32}}
+  // CHECK-DAG: %[[RESW1:.+]]  = "tosa.reshape"(%[[PADW]]) {new_shape = [5, 2, 2, 2, 3, 3]}
+  // CHECK-DAG: %[[TRANS:.+]]  = "tosa.transpose"(%[[RESW1]], %[[TRANSV]])
+  // CHECK-DAG: %[[RESW2:.+]]  = "tosa.reshape"(%[[TRANS]]) {new_shape = [30, 2, 2, 3]}
+  // CHECK-DAG: %[[REV1:.+]]  = "tosa.reverse"(%[[RESW2]]) {axis = 1 : i64}
+  // CHECK-DAG: %[[NEWWEIGHT:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
+
+  // Pad out the input matrix to handle the transpose conv.
+  // CHECK-DAG: %[[PAD:.+]]  = "tosa.const"() {value = dense<{{\[\[}}0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK-DAG: %[[TRANS2:.+]]  = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+  // CHECK-DAG: %[[NEWINPUT:.+]] = "tosa.pad"(%arg0, %[[PAD]]) {quantization_info = {input_zp = -22 : i32}}
+
+  // Manipulate the final shape.
+  // CHECK-DAG: %[[BIAS:.+]]  = "tosa.const"() {value = dense<0> : tensor<30xi32>}
+  // CHECK-DAG: %[[CONV:.+]] = "tosa.conv2d"(%[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]]) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, stride = [1, 1]}
+  // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, 18, 16, 2, 3, 5]}
+  // CHECK-DAG: %[[TRANS_OUT:.+]] = "tosa.transpose"(%[[RESHAPE_OUT_1]], %[[TRANS2]])
+  // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]}
+  // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]}
+  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = {input_zp = -22 : i32, weight_zp = 42 : i32}, out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
+  return %0 : tensor<2x35x47x5xi32>
+}


        


More information about the Mlir-commits mailing list