[Mlir-commits] [mlir] [mlir][tosa] Disallow shape type in function argument/return types (PR #175754)
Luke Hutton
llvmlistbot at llvm.org
Fri Mar 27 05:05:23 PDT 2026
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/175754
>From d707e56553d548a0446f26fd7157a0da86d20516 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 13 Jan 2026 11:45:44 +0000
Subject: [PATCH] [mlir][tosa] Disallow shape type in function argument/return
types
This commit adds an additional check to the TOSA validation pass to
disallow the use of shape types in function argument and return types.
The specification requires these types to be tensor types.
Change-Id: I8ec3dc6b8858fe2367b14e59e33e48a3aa2c37e0
---
.../Tosa/Transforms/TosaValidation.cpp | 20 ++++++++++++
mlir/test/Dialect/Tosa/dynamic_extension.mlir | 4 +--
mlir/test/Dialect/Tosa/invalid.mlir | 20 ++++++------
mlir/test/Dialect/Tosa/invalid_extension.mlir | 12 +++----
mlir/test/Dialect/Tosa/level_check.mlir | 32 +++++++++----------
.../tosa-validation-version-1p1-valid.mlir | 20 ++++++------
6 files changed, 64 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 35b4b862dbff7..b2a91b2441dc4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -164,6 +164,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return success();
}
+ LogicalResult applyFunctionSignatureCheck(func::FuncOp op);
LogicalResult applyLevelCheck(Operation *op);
LogicalResult applyAttributeCheck(Operation *op);
@@ -1373,6 +1374,19 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
return success();
}
+/// Verifies that the signature of `op` is TOSA compliant. The TOSA
+/// specification requires function argument and return types to be tensor
+/// types, so any `!tosa.shape` type appearing in the signature is rejected.
+///
+/// Arguments are checked before results. On the first offending type an
+/// error naming its position and type is emitted on `op` and failure is
+/// returned; otherwise success is returned.
+LogicalResult TosaValidation::applyFunctionSignatureCheck(func::FuncOp op) {
+  for (auto [index, type] : llvm::enumerate(op.getArgumentTypes())) {
+    if (isa<tosa::shapeType>(type))
+      return op.emitOpError()
+             << "function argument #" << index
+             << " must be a tensor type to be TOSA compliant, got " << type;
+  }
+
+  for (auto [index, type] : llvm::enumerate(op.getResultTypes())) {
+    if (isa<tosa::shapeType>(type))
+      return op.emitOpError()
+             << "function result #" << index
+             << " must be a tensor type to be TOSA compliant, got " << type;
+  }
+
+  return success();
+}
+
bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
@@ -1418,6 +1432,12 @@ void TosaValidation::runOnOperation() {
return signalPassFailure();
targetEnv = *maybeTargetEnv;
+ const auto functions = modOp.getOps<func::FuncOp>();
+ if (llvm::any_of(functions, [&](func::FuncOp func) {
+ return failed(applyFunctionSignatureCheck(func));
+ }))
+ return signalPassFailure();
+
modOp.walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index a1329afc3bb03..df1c091d82aee 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -88,8 +88,8 @@ func.func @test_avg_pool2d_non_const_zps(%arg0: tensor<1x32x32x8xf32>, %input_zp
// -----
-func.func @test_slice_shape_non_const_start_size(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> !tosa.shape<3> {
+func.func @test_slice_shape_non_const_start_size(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) {
%0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
%3 = tosa.slice_shape %0, %arg0, %arg1 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
- return %3 : !tosa.shape<3>
+ return
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b7334fb4246a7..1ffdd31a04c51 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1088,26 +1088,26 @@ func.func @test_shape_type(%arg0: !tosa.shape<-1>) -> !tosa.shape<-1> {
// -----
-func.func @test_const_shape() -> !tosa.shape<4> {
+func.func @test_const_shape() {
// expected-error at +1 {{'tosa.const_shape' op attribute 'values' failed to satisfy constraint: index elements attribute}}
%cst = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> !tosa.shape<4>
- return %cst : !tosa.shape<4>
+ return
}
// -----
-func.func @test_const_shape_values() -> !tosa.shape<5> {
+func.func @test_const_shape_values() {
// expected-error at +1 {{'tosa.const_shape' op expect number of elements in attribute values (4) to be equal to the rank (5) for the result shape type}}
%cst = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<5>
- return %cst : !tosa.shape<5>
+ return
}
// -----
-func.func @test_const_shape_values() -> !tosa.shape<4> {
+func.func @test_const_shape_values() {
// expected-error at +1 {{'tosa.const_shape' op expect elements in attribute values with rank 1}}
%cst = tosa.const_shape {values = dense<[[1, 2], [3, 4]]> : tensor<2x2xindex>} : () -> !tosa.shape<4>
- return %cst : !tosa.shape<4>
+ return
}
// -----
@@ -2070,22 +2070,22 @@ func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, ten
// -----
-func.func @test_slice_shape_non_const_start(%arg0: tensor<1xi32>) -> !tosa.shape<3> {
+func.func @test_slice_shape_non_const_start(%arg0: tensor<1xi32>) {
%0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
%2 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error at +1 {{'tosa.slice_shape' op expected compile time resolvable constant, but got variable value for operand #1}}
%3 = tosa.slice_shape %0, %arg0, %2 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
- return %3 : !tosa.shape<3>
+ return
}
// -----
-func.func @test_slice_shape_non_const_size(%arg0: tensor<1xi32>) -> !tosa.shape<3> {
+func.func @test_slice_shape_non_const_size(%arg0: tensor<1xi32>) {
%0 = tosa.const_shape {values = dense<[4, 5, 6, 7, 8, 9]> : tensor<6xindex>} : () -> !tosa.shape<6>
%1 = "tosa.const"() {values = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
// expected-error at +1 {{'tosa.slice_shape' op expected compile time resolvable constant, but got variable value for operand #2}}
%3 = tosa.slice_shape %0, %1, %arg0 : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<3>
- return %3 : !tosa.shape<3>
+ return
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 901d865f4caeb..5af96335dce9d 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -587,32 +587,32 @@ func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6
// -----
-func.func @test_mul_shape() -> !tosa.shape<4> {
+func.func @test_mul_shape() {
%a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
%b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error at +1 {{'tosa.mul_shape' op illegal: requires [shape] but not enabled in target}}
%c = tosa.mul_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
- return %c : !tosa.shape<4>
+ return
}
// -----
-func.func @test_max_shape() -> !tosa.shape<4> {
+func.func @test_max_shape() {
%a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
%b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error at +1 {{'tosa.max_shape' op illegal: requires [shape] but not enabled in target}}
%c = tosa.max_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
- return %c : !tosa.shape<4>
+ return
}
// -----
-func.func @test_min_shape() -> !tosa.shape<4> {
+func.func @test_min_shape() {
%a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
%b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error at +1 {{'tosa.min_shape' op illegal: requires [shape] but not enabled in target}}
%c = tosa.min_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
- return %c : !tosa.shape<4>
+ return
}
// -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d061da14bb109..bf1e84e34029d 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1134,9 +1134,9 @@ func.func @test_while_loop_tensor_size_invalid(%arg0: tensor<536870912xi32>, %ar
// -----
-func.func @test_const_shape() -> !tosa.shape<4> {
+func.func @test_const_shape() {
%cst = tosa.const_shape {values = dense<[1, 1, 536870912, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
- return %cst : !tosa.shape<4>
+ return
}
// -----
@@ -1665,67 +1665,67 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32
// -----
-func.func @test_add_shape_invalid_rank() -> !tosa.shape<17> {
+func.func @test_add_shape_invalid_rank() {
%a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
%b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
// expected-error at +1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
%c = tosa.add_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
- return %c : !tosa.shape<17>
+ return
}
// -----
-func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<17> {
+func.func @test_div_floor_shape_invalid_rank() {
%a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
%b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
// expected-error at +1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
%c = tosa.div_floor_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
- return %c : !tosa.shape<17>
+ return
}
// -----
-func.func @test_dim(%arg0: tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1> {
+func.func @test_dim(%arg0: tensor<1x2x3x4x5x6x7x8xi32>) {
// expected-error at +1 {{'tosa.dim' op failed level check: operand rank(shape) <= MAX_RANK}}
%0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4x5x6x7x8xi32>) -> !tosa.shape<1>
- return %0 : !tosa.shape<1>
+ return
}
// -----
-func.func @test_exp2_shape_invalid_rank() -> !tosa.shape<17> {
+func.func @test_exp2_shape_invalid_rank() {
%0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
// expected-error at +1 {{'tosa.exp2_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
%1 = tosa.exp2_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17>
- return %1 : !tosa.shape<17>
+ return
}
// -----
-func.func @test_log2_floor_shape_invalid_rank() -> !tosa.shape<17> {
+func.func @test_log2_floor_shape_invalid_rank() {
%0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
// expected-error at +1 {{'tosa.log2_floor_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
%1 = tosa.log2_floor_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17>
- return %1 : !tosa.shape<17>
+ return
}
// -----
-func.func @test_log2_ceil_shape_invalid_rank() -> !tosa.shape<17> {
+func.func @test_log2_ceil_shape_invalid_rank() {
%0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
// expected-error at +1 {{'tosa.log2_ceil_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
%1 = tosa.log2_ceil_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17>
- return %1 : !tosa.shape<17>
+ return
}
// -----
-func.func @test_mod_shape_invalid_rank() -> !tosa.shape<17> {
+func.func @test_mod_shape_invalid_rank() {
%a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
%b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
// expected-error at +1 {{'tosa.mod_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
%c = tosa.mod_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
- return %c : !tosa.shape<17>
+ return
}
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 4a6ab456744db..d0bbf76247dc3 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -352,19 +352,19 @@ func.func @test_dynamic_dims(%arg0: tensor<?x8x16xi8>) -> tensor<?x16xi32> {
// -----
// CHECK-LABEL: test_add_shape
-func.func @test_add_shape() -> !tosa.shape<4> {
+func.func @test_add_shape() {
%a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
%b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
%c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
- return %c : !tosa.shape<4>
+ return
}
// -----
// CHECK-LABEL: test_dim
-func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) -> !tosa.shape<1> {
+func.func @test_dim(%arg0: tensor<1x2x3x4xi32>) {
%0 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x2x3x4xi32>) -> !tosa.shape<1>
- return %0 : !tosa.shape<1>
+ return
}
// -----
@@ -377,28 +377,28 @@ func.func @test_dim_bf16(%0: tensor<6x4x6x9xbf16>) {
// -----
// CHECK-LABEL: test_exp2_shape
-func.func @test_exp2_shape() -> !tosa.shape<4> {
+func.func @test_exp2_shape() {
%a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
%b = tosa.exp2_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
- return %b : !tosa.shape<4>
+ return
}
// -----
// CHECK-LABEL: test_log2_ceil_shape
-func.func @test_log2_ceil_shape() -> !tosa.shape<4> {
+func.func @test_log2_ceil_shape() {
%a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
%b = tosa.log2_ceil_shape %a : (!tosa.shape<4>) -> !tosa.shape<4>
- return %b : !tosa.shape<4>
+ return
}
// -----
// CHECK-LABEL: test_mod_shape
-func.func @test_mod_shape() -> !tosa.shape<3> {
+func.func @test_mod_shape() {
%a = tosa.const_shape {values = dense<[10, 11, 12]> : tensor<3xindex>} : () -> !tosa.shape<3>
%b = tosa.const_shape {values = dense<[3, 5, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
%c = tosa.mod_shape %a, %b : (!tosa.shape<3>, !tosa.shape<3>) -> !tosa.shape<3>
- return %c : !tosa.shape<3>
+ return
}
// -----
More information about the Mlir-commits
mailing list