[Mlir-commits] [mlir] [mlir][tosa] Remove out_shape from transpose_conv2d (PR #129133)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 27 14:21:21 PST 2025
https://github.com/Jerry-Ge created https://github.com/llvm/llvm-project/pull/129133
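Remove the out_shape attribute from tosa.transpose_conv2d to match the TOSA v1.0 specification (https://www.mlplatform.org/tosa/tosa_spec.html#_transpose_conv2d). Shape inference for the op now starts from a fully dynamic output shape instead of reading out_shape, and the op builder and tests are updated accordingly.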
From 9bad5062e40e8cfc966bb96a150fa1df11047430 Mon Sep 17 00:00:00 2001
From: Suraj Sudhir <suraj.sudhir at arm.com>
Date: Fri, 22 Mar 2024 23:24:24 +0000
Subject: [PATCH] [mlir][tosa] Remove out_shape from transpose_conv2d
Signed-off-by: Suraj Sudhir <suraj.sudhir at arm.com>
Change-Id: I654a2489572859dafc0d0928cad8b4086ef1ba30
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 4 +---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 -
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 14 +++++------
mlir/test/Dialect/Tosa/invalid.mlir | 11 +--------
.../Tosa/tosa-decompose-transpose-conv.mlir | 23 ++++++++-----------
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 16 ++++++-------
6 files changed, 25 insertions(+), 44 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 23692478755c6..ce17ad9362227 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -148,13 +148,11 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
"::mlir::Value":$weight, "mlir::Value":$bias,
"::mlir::DenseI64ArrayAttr":$outpad,
"::mlir::DenseI64ArrayAttr":$stride,
- "::mlir::DenseI64ArrayAttr":$outputShape,
"::mlir::TypeAttr":$acc_type),
[{
buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
input, weight, bias,
- outpad, stride,
- outputShape, acc_type);
+ outpad, stride, acc_type);
}]>;
// The tosa.matmul op is also intended to be generated where a fully_connected
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..a9bf18f4af6b9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -408,7 +408,6 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
- Tosa_IntArrayAttr4:$out_shape,
TypeAttrOf<Tosa_AccType>:$acc_type,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..54f9fa917f2e0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -569,15 +569,15 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
-/// Handles tosa.transpose_conv2d which has outpad and output shape
-/// attributes.
+/// Handles tosa.transpose_conv2d, which has outpad, stride and acc_type
+/// attributes.
-static void buildTransConvOpWithQuantInfo(
- OpBuilder &builder, OperationState &result, Type outputType, Value input,
- Value weight, Value bias, DenseI64ArrayAttr outpad,
- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
+static void
+buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+ Type outputType, Value input, Value weight,
+ Value bias, DenseI64ArrayAttr outpad,
+ DenseI64ArrayAttr stride, TypeAttr accType) {
auto zps = createZPsAsConst(builder, input, weight);
result.addOperands({input, weight, bias, zps.first, zps.second});
result.addAttribute("out_pad", outpad);
result.addAttribute("stride", stride);
- result.addAttribute("out_shape", outputShape);
result.addAttribute("acc_type", accType);
Type finalOutputType = outputType;
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
@@ -2327,9 +2327,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TransposeConv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- // outputShape is mutable.
- llvm::SmallVector<int64_t> outputShape =
- convertToMlirShape(adaptor.getOutShape());
+ llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
int64_t inputWidth = ShapedType::kDynamic;
int64_t inputHeight = ShapedType::kDynamic;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..5b928a2489eea 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -168,7 +168,7 @@ func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tens
func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> {
%zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error at +1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}}
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
return %0 : tensor<1x32x32x16xi8>
}
@@ -741,15 +741,6 @@ func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6
// -----
-// CHECK-LABEL: test_transpose_conv2d_invalid_outshape
-func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
- // expected-error at +1 {{'tosa.transpose_conv2d' op attribute 'out_shape' failed to satisfy constraint: i64 dense array attribute with exactly 4 elements}}
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
- return %0 : tensor<1x32x32x16xf32>
-}
-
-// -----
-
// CHECK-LABEL: test_mul_type_mismatch
func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index bb3c16cf52d63..7235c52c5dd05 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -30,21 +30,17 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor
// -----
-// CHECK-LABEL: @transpose_conv2d_quantized_padded
func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
- // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}
- // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>}
- // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %2 {axis = 2 : i32}
- // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
- // CHECK: tosa.conv2d %arg0, %3, %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
- // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
- // CHECK-SAME: stride = array<i64: 1, 1>}
- %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
- %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+ // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
+ // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[REV0]] {axis = 2 : i32}
+ // CHECK: tosa.conv2d %arg0, %[[REV1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>, stride = array<i64: 1, 1>}
+ %input_zp = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() <{value = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
acc_type = i32,
out_pad = array<i64: 1, 2, 3, 4>,
- out_shape = array<i64: -1, -1, -1, -1>,
stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x21x26x5xi32>
return %0 : tensor<2x21x26x5xi32>
}
@@ -160,12 +156,11 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
// CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]
// CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[CONST10]]
// CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]]
- %input_zp = "tosa.const"() {value = dense<-103> : tensor<1xi8>} : () -> tensor<1xi8>
- %weight_zp = "tosa.const"() {value = dense<93> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() <{value = dense<-103> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() <{value = dense<93> : tensor<1xi8>}> : () -> tensor<1xi8>
%2 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
acc_type = i32,
out_pad = array<i64: 2, 0, 0, 1>,
- out_shape = array<i64: 1, -1, -1, 1>,
stride = array<i64: 1, 2>} :
(tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
"func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index b87e9a78bf144..8a3dbfe17d686 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -907,7 +907,7 @@ func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor<
// CHECK-LABEL: @transpose_conv2d_out_shape
func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x8x9x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, 8, 9, -1>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
return
}
@@ -916,7 +916,7 @@ func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<
// CHECK-LABEL: @transpose_conv2d_static
func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x18x19x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
return
}
@@ -925,7 +925,7 @@ func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5
// CHECK-LABEL: @transpose_conv2d_static_strided
func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x33x45x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
return
}
@@ -934,7 +934,7 @@ func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1:
// CHECK-LABEL: @transpose_conv2d_dynamic_input
func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<?x?x?x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x5xf32>
return
}
@@ -943,7 +943,7 @@ func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: ten
// CHECK-LABEL: @transpose_conv2d_dynamic_weights
func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x?x?x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
return
}
@@ -952,7 +952,7 @@ func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: t
// CHECK-LABEL: @transpose_conv2d_dynamic_bias
func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<?xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x8x9x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
return
}
@@ -961,14 +961,14 @@ func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tens
// CHECK-LABEL: @transpose_conv2d_padded
func.func @transpose_conv2d_padded(%arg0: tensor<2x9x11x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x10x13x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 1, 0, 3, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x10x13x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 1, 0, 3, 0>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x10x13x5xf32>
return
}
// CHECK-LABEL: @transpose_conv2d_strided
func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<1x13x13x1xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
return
}
More information about the Mlir-commits
mailing list