[Mlir-commits] [mlir] [mlir][tosa] Guard pooling shape inference on unranked inputs (PR #177999)
Yi-Chi Lee
llvmlistbot at llvm.org
Mon Jan 26 09:44:30 PST 2026
https://github.com/yichi170 updated https://github.com/llvm/llvm-project/pull/177999
>From f6aedddbc8931d05670bee29f549b3e70f97a3bd Mon Sep 17 00:00:00 2001
From: Yi-Chi Lee <yichi170 at gmail.com>
Date: Mon, 26 Jan 2026 11:24:58 -0600
Subject: [PATCH 1/2] [mlir][tosa] Guard pooling shape inference on unranked
inputs
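poolingInferReturnTypes did not properly guard against inputs of unknown rank: an unranked
input produces a non-null ShapeAdaptor, so the `!inputShape` check let it through and the
subsequent call to ShapeAdaptor::getDimSize triggered an assertion. Check hasRank() instead.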
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 2 +-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 13 +++++++++++++
2 files changed, 14 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6205161599899..c412d788c9b29 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3409,7 +3409,7 @@ static LogicalResult poolingInferReturnTypes(
outputShape.resize(4, ShapedType::kDynamic);
// We only know the rank if the input type is unranked.
- if (!inputShape) {
+ if (!inputShape.hasRank()) {
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 610fdb6d32ad4..5c73e37b47dbb 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1743,3 +1743,16 @@ func.func @test_tconv2d_bias_broadcast(%input: tensor<2x6x7x3xf32>, %weight: ten
: (tensor<2x6x7x3xf32>, tensor<?x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: test_pool2d_unknown_rank
+func.func @test_pool2d_unknown_rank() {
+ %0 = gpu.block_id x
+ %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %2 = tosa.reduce_sum %1 {axis = 0 : i32} : (tensor<1xi32>) -> tensor<1xi32>
+ %3 = tosa.bitwise_or %2, %2 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %4 = tensor.cast %3 : tensor<1xi32> to tensor<*xi32>
+ %5 = tosa.avg_pool2d %4, %1, %1 {acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
+ return
+}
>From 8ecf5dbce58dab3d1a38c7597168af1342d5a6a7 Mon Sep 17 00:00:00 2001
From: Yi-Chi Lee <yichi170 at gmail.com>
Date: Mon, 26 Jan 2026 11:43:58 -0600
Subject: [PATCH 2/2] Simplify test case to use an unranked function argument
---
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 12 ++++--------
1 file changed, 4 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 5c73e37b47dbb..bc5f41b1af304 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1746,13 +1746,9 @@ func.func @test_tconv2d_bias_broadcast(%input: tensor<2x6x7x3xf32>, %weight: ten
// -----
-// CHECK-LABEL: test_pool2d_unknown_rank
-func.func @test_pool2d_unknown_rank() {
- %0 = gpu.block_id x
- %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
- %2 = tosa.reduce_sum %1 {axis = 0 : i32} : (tensor<1xi32>) -> tensor<1xi32>
- %3 = tosa.bitwise_or %2, %2 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
- %4 = tensor.cast %3 : tensor<1xi32> to tensor<*xi32>
- %5 = tosa.avg_pool2d %4, %1, %1 {acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
+// CHECK-LABEL: test_avg_pool2d_unranked_input
+func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi32>) {
+ // CHECK: -> tensor<?x?x?x?xi32>
+ %0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
return
}
More information about the Mlir-commits
mailing list