[Mlir-commits] [mlir] [mlir][tosa] Shape operation level checks limited to MAX_SHAPE_LEN (PR #175020)
Luke Hutton
llvmlistbot at llvm.org
Fri Jan 9 14:04:35 PST 2026
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/175020
>From 573b807e7fb4442110e090f96b2fdb0ad6a9c7b0 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 7 Jan 2026 13:44:00 +0000
Subject: [PATCH 1/3] [mlir][tosa] Shape operation level checks limited to
MAX_SHAPE_LEN
As a result of a recent specification change:
https://github.com/arm/tosa-specification/pull/29,
the level checks for TOSA shape operations are limited to MAX_SHAPE_LEN
as opposed to MAX_RANK. The reason for doing so is detailed in the
specification commit message.
This change also removes prior code that incorrectly checked every
`shapeType` against MAX_RANK, rather than checking only the `shapeType`
operands and results of shape operations, further aligning with the specification.
Change-Id: I466497a14191721e2c484f360c92a00924e59699
---
mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 9 ++-
.../Tosa/Transforms/TosaValidation.cpp | 69 ++++++++++++-------
mlir/test/Dialect/Tosa/level_check.mlir | 26 +++----
.../Dialect/Tosa/tosa-validation-valid.mlir | 10 +++
4 files changed, 74 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index e088eb31338dc..b80232f112b64 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -28,19 +28,22 @@ struct TosaLevel {
int32_t MAX_LOG2_SIZE = 0;
int32_t MAX_NESTING = 0;
int32_t MAX_TENSOR_LIST_SIZE = 0;
+ int32_t MAX_SHAPE_LEN = 0;
bool operator==(const TosaLevel &rhs) {
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
MAX_NESTING == rhs.MAX_NESTING &&
- MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE &&
+ MAX_SHAPE_LEN == rhs.MAX_SHAPE_LEN;
}
};
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256,
+ 31, 6, 64, 16};
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
- 63, 256, 256};
+ 63, 256, 256, 64};
TargetEnvAttr lookupTargetEnv(Operation *op);
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index a900aef04f753..20b11fa8715ed 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -228,12 +228,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
if (type.getRank() > highest_rank)
return op->emitOpError() << "failed level check: " << operandOrResult
<< " rank(shape) <= MAX_RANK";
- } else if (tosa::shapeType shapeType =
- dyn_cast<tosa::shapeType>(typeToCheck)) {
- if (shapeType.getRank() > highest_rank)
- return op->emitOpError()
- << "failed shape type level check: " << typeToCheck
- << " exceeds MAX_RANK";
}
return success();
}
@@ -255,6 +249,18 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return levelCheckSize(op, v.getType(), operandOrResult);
}
+ // Perform the Level shape length check on a value.
+ LogicalResult levelCheckShapeLength(Operation *op, const Type typeToCheck,
+ const StringRef operandOrResult) {
+ if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
+ if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
+ return op->emitOpError()
+ << "failed shape type level check: " << typeToCheck
+ << " exceeds MAX_SHAPE_LEN";
+ }
+ return success();
+ }
+
// Level check sizes of all operands and results of the operation.
template <typename T>
LogicalResult levelCheckSizes(T tosaOp) {
@@ -288,6 +294,21 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return success();
}
+ // Level check shape lengths of all operands and results of an operation that
+ // are tosa.shape type.
+ template <typename T>
+ LogicalResult levelCheckShapeLens(T tosaOp) {
+ // auto op = tosaOp.getOperation();
+ for (const auto &v : tosaOp.getOperands()) {
+ if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
+ return failure();
+ }
+ if (failed(levelCheckShapeLength(tosaOp, tosaOp.getResult().getType(),
+ "result")))
+ return failure();
+ return success();
+ }
+
// Level check ranks and sizes.
LogicalResult levelCheckRanksAndSizes(Operation *op);
@@ -591,9 +612,9 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
return failure(); \
}
-#define CHECK_RANKS(tosaOp) \
+#define CHECK_SHAPE_LEN(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
+ if (failed(levelCheckShapeLens(cast<tosa::tosaOp##Op>(op)))) \
return failure(); \
}
@@ -700,27 +721,27 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
// Shape Operators
CHECK_SIZES(ConstShape);
- // For the following operations, check whether the rank of each operand
- // is valid given a level.
+ // For the following operations, check whether the shape length of each
+ // operand is valid given a level.
// Shape Operators
- CHECK_RANKS(AddShape);
- CHECK_RANKS(ConcatShape);
- CHECK_RANKS(DivCeilShape);
- CHECK_RANKS(DivFloorShape);
- CHECK_RANKS(Exp2Shape);
- CHECK_RANKS(Log2CeilShape);
- CHECK_RANKS(Log2FloorShape);
- CHECK_RANKS(MaxShape);
- CHECK_RANKS(MinShape);
- CHECK_RANKS(ModShape);
- CHECK_RANKS(MulShape);
- CHECK_RANKS(SliceShape);
- CHECK_RANKS(SubShape);
+ CHECK_SHAPE_LEN(AddShape);
+ CHECK_SHAPE_LEN(ConcatShape);
+ CHECK_SHAPE_LEN(DivCeilShape);
+ CHECK_SHAPE_LEN(DivFloorShape);
+ CHECK_SHAPE_LEN(Exp2Shape);
+ CHECK_SHAPE_LEN(Log2CeilShape);
+ CHECK_SHAPE_LEN(Log2FloorShape);
+ CHECK_SHAPE_LEN(MaxShape);
+ CHECK_SHAPE_LEN(MinShape);
+ CHECK_SHAPE_LEN(ModShape);
+ CHECK_SHAPE_LEN(MulShape);
+ CHECK_SHAPE_LEN(SliceShape);
+ CHECK_SHAPE_LEN(SubShape);
#undef CHECK_RANKS_AND_SIZES
#undef CHECK_SIZES
-#undef CHECK_RANKS
+#undef CHECK_SHAPE_LEN
return success();
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d874ab6d23a50..47de248b375e5 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -390,7 +390,7 @@ func.func @test_pad_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1
func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
%1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
- // expected-error at +1 {{'tosa.reshape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+ // expected-error at +1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
%0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
return %0 : tensor<1x1x1x1x1x1x819xf32>
}
@@ -1665,22 +1665,22 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32
// -----
-func.func @test_add_shape_invalid_rank() -> !tosa.shape<13> {
- %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
- %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
- // expected-error at +1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<13>' exceeds MAX_RANK}}
- %c = tosa.add_shape %a, %b : (!tosa.shape<13>, !tosa.shape<13>) -> !tosa.shape<13>
- return %c : !tosa.shape<13>
+func.func @test_add_shape_invalid_rank() -> !tosa.shape<17> {
+ %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ // expected-error at +1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+ %c = tosa.add_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
+ return %c : !tosa.shape<17>
}
// -----
-func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
- %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
- %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
- // expected-error at +1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
- %c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
- return %c : !tosa.shape<7>
+func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<17> {
+ %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ // expected-error at +1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+ %c = tosa.div_floor_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
+ return %c : !tosa.shape<17>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
index 036a3d4d0ba2e..d2d010d3a0845 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -37,3 +37,13 @@ func.func @test_validate_without_tosa(%arg0: f32) -> f32 {
%0 = math.asin %arg0 : f32
return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: test_pad_large_input_rank
+func.func @test_pad_large_input_rank(%arg0: tensor<13x21x3x1x1x1xf32>) -> tensor<13x21x3x1x1x1xf32> {
+ %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+ %padding = tosa.const_shape {values = dense<0> : tensor<12xindex>} : () -> !tosa.shape<12>
+ %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3x1x1x1xf32>, !tosa.shape<12>, tensor<1xf32>) -> tensor<13x21x3x1x1x1xf32>
+ return %1 : tensor<13x21x3x1x1x1xf32>
+}
>From c8e3535a0b54df06053a48d2681eddf7ac9d7462 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 9 Jan 2026 11:24:10 +0000
Subject: [PATCH 2/3] Remove commented line and rename function
Change-Id: I20b2ece8368cff09df20481ae335c84820a37d59
---
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 20b11fa8715ed..de795c6bd9088 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -297,8 +297,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
// Level check shape lengths of all operands and results of an operation that
// are tosa.shape type.
template <typename T>
- LogicalResult levelCheckShapeLens(T tosaOp) {
- // auto op = tosaOp.getOperation();
+ LogicalResult levelCheckShapeLengths(T tosaOp) {
for (const auto &v : tosaOp.getOperands()) {
if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
return failure();
@@ -614,7 +613,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
#define CHECK_SHAPE_LEN(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- if (failed(levelCheckShapeLens(cast<tosa::tosaOp##Op>(op)))) \
+ if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
return failure(); \
}
>From f5ec722db49e4b8ebf27a2324875053523cd7177 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 9 Jan 2026 11:36:40 +0000
Subject: [PATCH 3/3] Fixup mod_shape test
Change-Id: I06dadb7911e238540237c50c6f85adf188fd2258
---
.../Tosa/Transforms/TosaValidation.cpp | 2 +-
mlir/test/Dialect/Tosa/level_check.mlir | 42 +++++++++----------
2 files changed, 22 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index de795c6bd9088..387d38411f0fe 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -298,7 +298,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
// are tosa.shape type.
template <typename T>
LogicalResult levelCheckShapeLengths(T tosaOp) {
- for (const auto &v : tosaOp.getOperands()) {
+ for (const auto &v : tosaOp->getOperands()) {
if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
return failure();
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 47de248b375e5..dd5ece417cf9e 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1721,37 +1721,37 @@ func.func @test_concat_shape_invalid_list_size() {
// -----
-func.func @test_exp2_shape_invalid_rank() -> !tosa.shape<7> {
- %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
- // expected-error at +1 {{'tosa.exp2_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
- %1 = tosa.exp2_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
- return %1 : !tosa.shape<7>
+func.func @test_exp2_shape_invalid_rank() -> !tosa.shape<17> {
+ %0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ // expected-error at +1 {{'tosa.exp2_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+ %1 = tosa.exp2_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17>
+ return %1 : !tosa.shape<17>
}
// -----
-func.func @test_log2_floor_shape_invalid_rank() -> !tosa.shape<7> {
- %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
- // expected-error at +1 {{'tosa.log2_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
- %1 = tosa.log2_floor_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
- return %1 : !tosa.shape<7>
+func.func @test_log2_floor_shape_invalid_rank() -> !tosa.shape<17> {
+ %0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ // expected-error at +1 {{'tosa.log2_floor_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+ %1 = tosa.log2_floor_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17>
+ return %1 : !tosa.shape<17>
}
// -----
-func.func @test_log2_ceil_shape_invalid_rank() -> !tosa.shape<7> {
- %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
- // expected-error at +1 {{'tosa.log2_ceil_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
- %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
- return %1 : !tosa.shape<7>
+func.func @test_log2_ceil_shape_invalid_rank() -> !tosa.shape<17> {
+ %0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ // expected-error at +1 {{'tosa.log2_ceil_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+ %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17>
+ return %1 : !tosa.shape<17>
}
// -----
-func.func @test_mod_shape_invalid_rank() -> !tosa.shape<9> {
- %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<9xindex>} : () -> !tosa.shape<9>
- %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<9xindex>} : () -> !tosa.shape<9>
- // expected-error at +1 {{'tosa.mod_shape' op failed shape type level check: '!tosa.shape<9>' exceeds MAX_RANK}}
- %c = tosa.mod_shape %a, %b : (!tosa.shape<9>, !tosa.shape<9>) -> !tosa.shape<9>
- return %c : !tosa.shape<9>
+func.func @test_mod_shape_invalid_rank() -> !tosa.shape<17> {
+ %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ // expected-error at +1 {{'tosa.mod_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+ %c = tosa.mod_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
+ return %c : !tosa.shape<17>
}
More information about the Mlir-commits
mailing list