[Mlir-commits] [mlir] 044e7e0 - [mlir][tosa] Fixed shape inference for tosa.transpose_conv2d
Rob Suderman
llvmlistbot at llvm.org
Wed Nov 17 15:03:36 PST 2021
Author: Rob Suderman
Date: 2021-11-17T14:59:52-08:00
New Revision: 044e7e013ef066590282549f31ab51a64359706b
URL: https://github.com/llvm/llvm-project/commit/044e7e013ef066590282549f31ab51a64359706b
DIFF: https://github.com/llvm/llvm-project/commit/044e7e013ef066590282549f31ab51a64359706b.diff
LOG: [mlir][tosa] Fixed shape inference for tosa.transpose_conv2d
Shape inference for transpose conv2d was incorrect, and the tests did not
properly validate that shape inference was executing. Corrected the shape
inference and extended the tests so that they actually exercise it.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D114026
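To illustrate why the choice of operand matters, here is a minimal standalone
sketch (not the MLIR implementation). It assumes the filter layout
[out_channels, kernel_h, kernel_w, in_channels] implied by the test shapes;
FilterDims and readFilterDims are hypothetical names introduced only for this
example.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical container for the values shape inference takes from the filter.
struct FilterDims {
  int64_t outChannels;
  int64_t kernelHeight;
  int64_t kernelWidth;
};

// Reads the filter-derived dimensions from a rank-4 shape.
// Assumed layout: [out_channels, kernel_h, kernel_w, in_channels].
FilterDims readFilterDims(const std::vector<int64_t> &shape) {
  assert(shape.size() == 4 && "expected a rank-4 shape");
  return FilterDims{shape[0], shape[1], shape[2]};
}
```

Passing the filter shape 5x3x6x3 yields 5 output channels and a 3x6 kernel.
Passing the input shape 2x16x14x3 by mistake yields 2, 16 and 14 instead,
which is the kind of mix-up this change removes.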
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e32be42a31015..aaba64afb71d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1625,7 +1625,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
}
// Weight shapes describes the filter width/height and the output channels.
- ShapeAdaptor weightShape = operands.getShape(adaptor.input());
+ ShapeAdaptor weightShape = operands.getShape(adaptor.filter());
if (weightShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? weightShape.getDimSize(0)
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index edcef6ba698b5..fdf2528f76854 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -919,9 +919,27 @@ func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6
// -----
// CHECK-LABEL: @transpose_conv2d_static
-func @transpose_conv2d_static(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
- // CHECK: -> tensor<2x8x9x5xf32>
- %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32>
+func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+ // CHECK: -> tensor<2x18x19x5xf32>
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_static_dilated
+func @transpose_conv2d_static_dilated(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+ // CHECK: -> tensor<2x20x29x5xf32>
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 3], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_static_strided
+func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+ // CHECK: -> tensor<2x33x45x5xf32>
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
return
}
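
The spatial sizes in the three CHECK lines above can be reproduced with the
usual transposed-convolution size formula. The sketch below is a standalone
illustration, not the code in TosaOps.cpp. It assumes out_pad is zero, as in
all three tests, and transposeConvOutDim is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// Output size along one spatial axis of a transposed convolution,
// assuming zero output padding (out_pad = [0, 0] in the tests).
int64_t transposeConvOutDim(int64_t in, int64_t kernel, int64_t stride,
                            int64_t dilation) {
  assert(in > 0 && kernel > 0 && stride > 0 && dilation > 0);
  // A dilated kernel spans (kernel - 1) * dilation + 1 positions.
  int64_t dilatedKernel = (kernel - 1) * dilation + 1;
  // The last input element lands at offset (in - 1) * stride, and the
  // dilated kernel extends from there.
  return (in - 1) * stride + dilatedKernel;
}
```

With the 16x14 input and 3x6 kernel used in the tests, this gives 18x19 for
unit stride and dilation, 20x29 for dilation [2, 3], and 33x45 for
stride [2, 3].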