[Mlir-commits] [mlir] 5821b32 - [mlir][tosa] Fix shape inference for broadcast bias in transpose_conv2d and depthwise_conv2d (#177739)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 26 06:28:46 PST 2026


Author: Yi-Chi Lee
Date: 2026-01-26T14:28:40Z
New Revision: 5821b324f69ddcf381fa4c72bb01b512d02bc2a4

URL: https://github.com/llvm/llvm-project/commit/5821b324f69ddcf381fa4c72bb01b512d02bc2a4
DIFF: https://github.com/llvm/llvm-project/commit/5821b324f69ddcf381fa4c72bb01b512d02bc2a4.diff

LOG: [mlir][tosa] Fix shape inference for broadcast bias in transpose_conv2d and depthwise_conv2d (#177739)

Fixes part of #175765.

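Correct shape inference for `tosa.transpose_conv2d` and
`tosa.depthwise_conv2d` when the bias tensor has `BC == 1`. Per the
specification, such a bias is broadcast, so its size must not be used to
infer the output channel dimension.
Also fix how the bias shape is obtained in `transpose_conv2d`, which
previously read the shape of the input tensor instead of the bias.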

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 033feb79405df..6205161599899 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3876,10 +3876,10 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
 
   // Bias shape can describe the output channels.
   ShapeAdaptor biasShape(adaptor.getBias().getType());
-  if (biasShape.hasRank()) {
-    outputShape[3] = ShapedType::isDynamic(outputShape[3])
-                         ? biasShape.getDimSize(0)
-                         : outputShape[3];
+  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
+    int64_t bc = biasShape.getDimSize(0);
+    if (bc != ShapedType::kDynamic && bc != 1)
+      outputShape[3] = bc;
   }
 
   llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
@@ -3943,11 +3943,11 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   }
 
   // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape(adaptor.getInput().getType());
-  if (biasShape.hasRank()) {
-    outputShape[3] = ShapedType::isDynamic(outputShape[3])
-                         ? biasShape.getDimSize(0)
-                         : outputShape[3];
+  ShapeAdaptor biasShape(adaptor.getBias().getType());
+  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
+    int64_t bc = biasShape.getDimSize(0);
+    if (bc != ShapedType::kDynamic && bc != 1)
+      outputShape[3] = bc;
   }
 
   llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index b74540f060cfe..610fdb6d32ad4 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1721,3 +1721,25 @@ func.func @test_conv2d_block_scaled_dynamic_unranked(%arg0: tensor<*xf4E2M1FN>,
   %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_dwconv2d_bias_broadcast
+func.func @test_dwconv2d_bias_broadcast(%input: tensor<2x8x9x?xf32>, %weight: tensor<3x3x?x?xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
+  // CHECK: -> tensor<2x6x7x?xf32>
+  %0 = tosa.depthwise_conv2d %input, %weight, %bias, %input_zp, %weight_zp
+       { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
+       : (tensor<2x8x9x?xf32>, tensor<3x3x?x?xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tconv2d_bias_broadcast
+func.func @test_tconv2d_bias_broadcast(%input: tensor<2x6x7x3xf32>, %weight: tensor<?x3x3x3xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
+  // CHECK: -> tensor<2x8x9x?xf32>
+  %0 = tosa.transpose_conv2d %input, %weight, %bias, %input_zp, %weight_zp
+       { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
+       : (tensor<2x6x7x3xf32>, tensor<?x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
+    return
+  }


        


More information about the Mlir-commits mailing list