[Mlir-commits] [mlir] b7f4335 - [mlir][tosa] Update TOSA transpose_conv2d to match spec

Robert Suderman llvmlistbot at llvm.org
Fri Jul 1 12:13:00 PDT 2022


Author: Eric Kunze
Date: 2022-07-01T19:10:28Z
New Revision: b7f4335d6a99213a4f84d367325da882124c9fab

URL: https://github.com/llvm/llvm-project/commit/b7f4335d6a99213a4f84d367325da882124c9fab
DIFF: https://github.com/llvm/llvm-project/commit/b7f4335d6a99213a4f84d367325da882124c9fab.diff

LOG: [mlir][tosa] Update TOSA transpose_conv2d to match spec

The TOSA specification does not define a dilation attribute for transpose_conv2d,
and its out_pad array has four elements (top, bottom, left, right).

This change updates the dialect to match the specification and updates the lit
tests to match the dialect changes.

Differential Revision: https://reviews.llvm.org/D127332
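
For illustration, drawn directly from the ops.mlir change below: the dilation
attribute is dropped and out_pad grows from two to four entries, so the op form
changes from

  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 32, 32, 16], stride = [1, 1]} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>

to

  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [1, 32, 32, 16], stride = [1, 1]} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>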

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
    mlir/test/Dialect/Tosa/ops.mlir
    mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 373bab35d1785..0f0898321a321 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -134,12 +134,11 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
 // Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
 def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
   (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias,
-       "ArrayAttr":$outpad, "ArrayAttr":$stride, "ArrayAttr":$dilation,
-       "ArrayAttr":$outputShape),
+       "ArrayAttr":$outpad, "ArrayAttr":$stride, "ArrayAttr":$outputShape),
   [{
     buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
                                   input, weight, bias,
-                                  outpad, stride, dilation,
+                                  outpad, stride,
                                   outputShape);
   }]>;
 

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 66d55d7435263..c88f8729c83da 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -289,9 +289,8 @@ def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
     Tosa_Tensor4D:$filter,
     Tosa_Tensor1D:$bias,
 
-    Tosa_IntArrayAttr2:$out_pad,
+    Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
-    Tosa_IntArrayAttr2:$dilation,
     Tosa_IntArrayAttrUpto4:$out_shape,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
   );

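(Note: Tosa_IntArrayAttr4, like the other Tosa_IntArrayAttrN helpers in the
dialect, constrains the attribute to exactly four i64 entries, so the verifier
now rejects the old two-element form. The expected ordering is the spec's
(top, bottom, left, right), e.g. out_pad = [1, 0, 3, 0] in the updated
@transpose_conv2d_padded test below.)
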
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 89fc4e318334f..93fddb5e5d6c8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -717,15 +717,15 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
 }
 
 /// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
-static void
-buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
-                              Type outputType, Value input, Value weight,
-                              Value bias, ArrayAttr outpad, ArrayAttr stride,
-                              ArrayAttr dilation, ArrayAttr outputShape) {
+static void buildTransConvOpWithQuantInfo(OpBuilder &builder,
+                                          OperationState &result,
+                                          Type outputType, Value input,
+                                          Value weight, Value bias,
+                                          ArrayAttr outpad, ArrayAttr stride,
+                                          ArrayAttr outputShape) {
   result.addOperands({input, weight, bias});
   result.addAttribute("out_pad", outpad);
   result.addAttribute("stride", stride);
-  result.addAttribute("dilation", dilation);
   result.addAttribute("out_shape", outputShape);
   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
 
@@ -1817,26 +1817,23 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
                          : outputShape[3];
   }
 
-  llvm::SmallVector<int64_t> dilation;
   llvm::SmallVector<int64_t> padding;
   llvm::SmallVector<int64_t> stride;
 
-  getI64Values(adaptor.dilation(), dilation);
   getI64Values(adaptor.out_pad(), padding);
   getI64Values(adaptor.stride(), stride);
 
   if (!ShapedType::isDynamic(inputHeight) &&
       !ShapedType::isDynamic(weightHeight)) {
-    int32_t dilated = (weightHeight - 1) * dilation[0] + 1;
     int32_t calculateSize =
-        (inputHeight - 1) * stride[0] - padding[0] + dilated;
+        (inputHeight - 1) * stride[0] - padding[0] - padding[1] + weightHeight;
     outputShape[1] = outputShape[1] == -1 ? calculateSize : outputShape[1];
   }
 
   if (!ShapedType::isDynamic(inputWidth) &&
       !ShapedType::isDynamic(weightWidth)) {
-    int32_t dilated = (weightWidth - 1) * dilation[1] + 1;
-    int32_t calculateSize = (inputWidth - 1) * stride[1] - padding[1] + dilated;
+    int32_t calculateSize =
+        (inputWidth - 1) * stride[1] - padding[2] - padding[3] + weightWidth;
     outputShape[2] = outputShape[2] == -1 ? calculateSize : outputShape[2];
   }
 

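With the four-element (top, bottom, left, right) layout, shape inference now
computes each spatial dimension as

  out = (in - 1) * stride - pad_before - pad_after + kernel

For example, in the @transpose_conv2d_static test updated below (input
2x16x14x3, weight 5x3x6x3, stride [1, 1], out_pad [0, 0, 0, 0]):

  height: (16 - 1) * 1 - 0 - 0 + 3 = 18
  width:  (14 - 1) * 1 - 0 - 0 + 6 = 19

matching the inferred tensor<2x18x19x5xf32>. With stride [2, 3] the same
weights give (16 - 1) * 2 + 3 = 33 and (14 - 1) * 3 + 6 = 45, i.e. the
tensor<2x33x45x5xf32> checked in @transpose_conv2d_static_strided.
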
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 120e5024837d4..37faff61eb80e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -79,7 +79,7 @@ TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
   return op;
 }
 
-class TransposeConvDilatedConverter
+class TransposeConvNonStridedConverter
     : public OpRewritePattern<tosa::TransposeConv2DOp> {
 public:
   using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
@@ -97,11 +97,9 @@ class TransposeConvDilatedConverter
 
     llvm::SmallVector<int64_t> pad;
     llvm::SmallVector<int64_t> stride;
-    llvm::SmallVector<int64_t> dilation;
 
     getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
     getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
-    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
 
     // If striding is all 1 we can modify padding and reverse the kernel along
     // the x/y direction to make it a regular convolution. This is much simpler
@@ -113,16 +111,14 @@ class TransposeConvDilatedConverter
         !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
       return failure();
 
-    int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
-    int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
-    int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
-    int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
+    int64_t kernelHeight = weightTy.getDimSize(1);
+    int64_t kernelWidth = weightTy.getDimSize(2);
 
     llvm::SmallVector<int64_t> convPad(4, 0);
     convPad[0] = kernelHeight - 1 - pad[0];
-    convPad[2] = kernelWidth - 1 - pad[1];
-    convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
-    convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
+    convPad[1] = kernelHeight - 1 - pad[1];
+    convPad[2] = kernelWidth - 1 - pad[2];
+    convPad[3] = kernelWidth - 1 - pad[3];
 
     auto reverse1 = rewriter.create<tosa::ReverseOp>(
         loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
@@ -134,12 +130,12 @@ class TransposeConvDilatedConverter
       conv2d = rewriter.create<tosa::Conv2DOp>(
           loc, resultTy, input, reverse2, bias,
           rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
-          rewriter.getI64ArrayAttr(dilation), *op.quantization_info());
+          rewriter.getI64ArrayAttr({1, 1}), *op.quantization_info());
     } else {
       conv2d = rewriter.create<tosa::Conv2DOp>(
           loc, resultTy, input, reverse2, bias,
           rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
-          rewriter.getI64ArrayAttr(dilation));
+          rewriter.getI64ArrayAttr({1, 1}));
     }
 
     rewriter.replaceOp(op, conv2d);
@@ -170,17 +166,13 @@ class TransposeConvStridedConverter
 
     llvm::SmallVector<int64_t> pad;
     llvm::SmallVector<int64_t> stride;
-    llvm::SmallVector<int64_t> dilation;
 
     getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
     getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
-    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
 
     // If striding is all 1 we can modify padding and reverse the kernel along
     // the x/y direction to make it a regular convolution. This is much simpler
    // than handling striding...
-    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
-      return failure();
 
    // If strides are all 1 we don't need to use this one.
     if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
@@ -350,7 +342,7 @@ class TransposeConvStridedConverter
     llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(),
                                             resultTy.getShape().begin());
     sliceBegin[1] = pad[0];
-    sliceBegin[2] = pad[1];
+    sliceBegin[2] = pad[2];
 
     auto slice = createOpAndInfer<tosa::SliceOp>(
                      rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
@@ -371,6 +363,6 @@ class TransposeConvStridedConverter
 
 void mlir::tosa::populateTosaDecomposeTransposeConv(
     MLIRContext *ctx, RewritePatternSet &patterns) {
-  patterns.add<TransposeConvDilatedConverter>(ctx);
+  patterns.add<TransposeConvNonStridedConverter>(ctx);
   patterns.add<TransposeConvStridedConverter>(ctx);
 }

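
Two arithmetic notes on the decomposition above. In the non-strided path the
transpose convolution becomes a plain tosa.conv2d on a kernel reversed along
both spatial axes, padded by kernel - 1 - out_pad on each side; for the first
test below (weight 5x3x6x3, out_pad [0, 0, 0, 0]) that yields

  convPad = [3-1-0, 3-1-0, 6-1-0, 6-1-0] = [2, 2, 5, 5]

which is exactly the pad = [2, 2, 5, 5] the CHECK line expects. In the strided
path, sliceBegin[2] switches from pad[1] to pad[2] because the left padding now
lives at index 2 of the four-element array.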
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 38d356871dc16..0514e7500798d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -75,7 +75,7 @@ func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32
 // -----
 /// CHECK-LABEL: transpose_conv2d
 func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 32, 32, 16], stride = [1, 1]} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [1, 32, 32, 16], stride = [1, 1]} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
   return %0 : tensor<1x32x32x16xf32>
 }
 
@@ -506,11 +506,11 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 // CHECK-LABEL: cond_if
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
   %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
-  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = "tosa.add"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
     "tosa.yield"(%1) : (tensor<f32>) -> ()
   },  {
-  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %1 = "tosa.sub"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
     "tosa.yield"(%1) : (tensor<f32>) -> ()
   }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
@@ -522,12 +522,12 @@ func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1
 func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
   %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
   %1:3 = "tosa.while_loop"(%0, %0, %arg0) ({
-  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):  
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
     %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
     %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
     "tosa.yield"(%3) : (tensor<i1>) -> ()
   },  {
-  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):  
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
     %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
     %4 = "tosa.reshape"(%2) {new_shape = [1]} : (tensor<i32>) -> tensor<1xi32>

diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 90121fef94905..41d295a566d45 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -5,7 +5,7 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x
   // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
   // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
   // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], stride = [1, 1]}
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
   %1 = tensor.cast %0 : tensor<2x18x19x5xf32> to tensor<2x?x?x5xf32>
   return %1 : tensor<2x?x?x5xf32>
 }
@@ -17,23 +17,22 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor
   // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
   // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
   // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = [1, 1]}
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
   return %0 : tensor<2x18x19x5xi32>
 }
 
-// ----
+// -----
 
-// CHECK-LABEL: @transpose_conv2d_dilated
-func.func @transpose_conv2d_dilated(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+// CHECK-LABEL: @transpose_conv2d_quantized_padded
+func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x7x7x18xi8>, %arg1: tensor<12x3x5x18xi8>, %arg2: tensor<12xi32>) -> (tensor<2x7x7x12xi32>) {
   // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
   // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
-  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [2, 3], pad = [4, 4, 15, 15], stride = [1, 1]}
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 3], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x20x29x5xf32>
-  %1 = tensor.cast %0 : tensor<2x20x29x5xf32> to tensor<2x?x?x5xf32>
-  return %1 : tensor<2x?x?x5xf32>
+  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [1, 1, 2, 2], quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = [1, 1]}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [1, 1, 2, 2], quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x7x7x18xi8>, tensor<12x3x5x18xi8>, tensor<12xi32>) -> tensor<2x7x7x12xi32>
+  return %0 : tensor<2x7x7x12xi32>
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: @transpose_conv2d_strided
 func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
@@ -60,12 +59,12 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
   // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]}
   // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]}
   // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
   %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
   return %1 : tensor<2x?x?x5xf32>
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: @transpose_conv2d_strided_quantized
 func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
@@ -92,6 +91,6 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = "tosa.reshape"(%[[TRANS_OUT]]) {new_shape = [2, 36, 48, 5]}
   // CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[RESHAPE_OUT_2]]) {size = [2, 35, 47, 5], start = [0, 0, 0, 0]}
   // CHECK: %[[ADD:.+]] = "tosa.add"(%[[SLICE]], %arg2)
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
   return %0 : tensor<2x35x47x5xi32>
 }

diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 98a2cd1c0e31f..2386b3fa8e6ad 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -912,7 +912,7 @@ func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor<
 // CHECK-LABEL: @transpose_conv2d_out_shape
 func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
   // CHECK: -> tensor<2x8x9x5xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, 8, 9, -1], stride = [1, 1]} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [-1, 8, 9, -1], stride = [1, 1]} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32>
   return
 }
 
@@ -921,16 +921,7 @@ func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<
 // CHECK-LABEL: @transpose_conv2d_static
 func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
   // CHECK: -> tensor<2x18x19x5xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_conv2d_static_dilated
-func.func @transpose_conv2d_static_dilated(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
-  // CHECK: -> tensor<2x20x29x5xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 3], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
   return
 }
 
@@ -939,7 +930,7 @@ func.func @transpose_conv2d_static_dilated(%arg0: tensor<2x16x14x3xf32>, %arg1:
 // CHECK-LABEL: @transpose_conv2d_static_strided
 func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
   // CHECK: -> tensor<2x33x45x5xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
   return
 }
 
@@ -948,7 +939,7 @@ func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1:
 // CHECK-LABEL: @transpose_conv2d_dynamic_input
 func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
   // CHECK: -> tensor<?x?x?x5xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32>
   return
 }
 
@@ -957,7 +948,7 @@ func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: ten
 // CHECK-LABEL: @transpose_conv2d_dynamic_weights
 func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<5xf32>) {
   // CHECK: -> tensor<2x?x?x5xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
   return
 }
 
@@ -966,7 +957,7 @@ func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: t
 // CHECK-LABEL: @transpose_conv2d_dynamic_bias
 func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<?xf32>) {
   // CHECK: -> tensor<2x8x9x5xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<2x8x9x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<2x8x9x5xf32>
   return
 }
 
@@ -975,25 +966,14 @@ func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tens
 // CHECK-LABEL: @transpose_conv2d_padded
 func.func @transpose_conv2d_padded(%arg0: tensor<2x9x11x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
   // CHECK: -> tensor<2x10x13x5xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [1, 3], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x10x13x5xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_conv2d_dilated
-func.func @transpose_conv2d_dilated(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
-  // CHECK: -> tensor<2x12x14x5xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [3, 2], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x12x14x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [1, 0, 3, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x10x13x5xf32>
   return
 }
 
-// -----
-
 // CHECK-LABEL: @transpose_conv2d_strided
 func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>) {
   // CHECK: -> tensor<1x13x13x1xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [3, 2]} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [-1, -1, -1, -1], stride = [3, 2]} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
   return
 }
 
@@ -1125,7 +1105,7 @@ func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : t
 
 // CHECK-LABEL: @while_test
 func.func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
-  // CHECK:      "tosa.add" 
+  // CHECK:      "tosa.add"
   // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
   %0 = "tosa.add"(%arg0, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
 
