[Mlir-commits] [mlir] 044e7e0 - [mlir][tosa] Fixed shape inference for tosa.transpose_conv2d

Rob Suderman llvmlistbot at llvm.org
Wed Nov 17 15:03:36 PST 2021


Author: Rob Suderman
Date: 2021-11-17T14:59:52-08:00
New Revision: 044e7e013ef066590282549f31ab51a64359706b

URL: https://github.com/llvm/llvm-project/commit/044e7e013ef066590282549f31ab51a64359706b
DIFF: https://github.com/llvm/llvm-project/commit/044e7e013ef066590282549f31ab51a64359706b.diff

LOG: [mlir][tosa] Fixed shape inference for tosa.transpose_conv2d

Transpose conv2d shape inference was incorrect, and the tests did not properly
validate that the shape inference was executing. Corrected the shape inference
and extended the tests so that they actually execute it.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D114026

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e32be42a31015..aaba64afb71d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1625,7 +1625,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape = operands.getShape(adaptor.input());
+  ShapeAdaptor weightShape = operands.getShape(adaptor.filter());
   if (weightShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? weightShape.getDimSize(0)

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index edcef6ba698b5..fdf2528f76854 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -919,9 +919,27 @@ func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6
 // -----
 
 // CHECK-LABEL: @transpose_conv2d_static
-func @transpose_conv2d_static(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
-  // CHECK: -> tensor<2x8x9x5xf32>
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32>
+func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x18x19x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_static_dilated
+func @transpose_conv2d_static_dilated(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x20x29x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 3], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_static_strided
+func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x33x45x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
   return
 }
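
The expected shapes in the three CHECK lines above can be reproduced by hand.
The sketch below is not the code from TosaOps.cpp; it is a minimal standalone
model that is consistent with the tests in this diff. The names
transposeConvDim and inferTransposeConv2dShape are hypothetical, the filter is
assumed to be laid out as [out_channels, kernel_h, kernel_w, in_channels], and
the sign applied to out_pad is an assumption, since every test here uses
out_pad = [0, 0].

```cpp
#include <array>
#include <cstdint>

using Shape4 = std::array<int64_t, 4>;
using Pair = std::array<int64_t, 2>;

// Hypothetical helper: output extent along one spatial axis.
// The kernel is first widened by the dilation, then placed once per
// strided input position.
inline int64_t transposeConvDim(int64_t in, int64_t kernel, int64_t stride,
                                int64_t dilation, int64_t outPad) {
  const int64_t dilatedKernel = (kernel - 1) * dilation + 1;
  // Assumption: out_pad is subtracted; the tests only cover out_pad == 0.
  return (in - 1) * stride - outPad + dilatedKernel;
}

// Hypothetical model of the inferred result shape for an NHWC input.
// The output channel count comes from dimension 0 of the filter, which is
// the operand the fix above switches to (it previously read the input).
inline Shape4 inferTransposeConv2dShape(const Shape4 &input,
                                        const Shape4 &filter,
                                        const Pair &stride,
                                        const Pair &dilation,
                                        const Pair &outPad) {
  return Shape4{
      input[0],
      transposeConvDim(input[1], filter[1], stride[0], dilation[0], outPad[0]),
      transposeConvDim(input[2], filter[2], stride[1], dilation[1], outPad[1]),
      filter[0]};
}
```

For the 2x16x14x3 input and 5x3x6x3 filter used in the tests, this gives
2x18x19x5 with unit stride and dilation, 2x20x29x5 with dilation = [2, 3], and
2x33x45x5 with stride = [2, 3]. Reading dimension 0 from the input instead of
the filter would yield 2 rather than 5 for the channel dimension.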
 


        

