[Mlir-commits] [mlir] [mlir][tosa] Shape operation level checks limited to MAX_SHAPE_LEN (PR #175020)
Luke Hutton
llvmlistbot at llvm.org
Fri Jan 9 03:27:12 PST 2026
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/175020
>From b9f2c4be54b05464b021875966a42d6ff7028b1b Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 7 Jan 2026 13:44:00 +0000
Subject: [PATCH 1/2] [mlir][tosa] Shape operation level checks limited to
MAX_SHAPE_LEN
Following a recent specification change:
https://github.com/arm/tosa-specification/pull/29,
the level checks for TOSA shape operations are now limited to MAX_SHAPE_LEN
rather than MAX_RANK. The rationale is detailed in the
specification commit message.
This change also removes prior code which incorrectly checked all
`shapeType` values, rather than checking the `shapeType` levels of just
shape operations, further aligning with the specification.
Change-Id: I466497a14191721e2c484f360c92a00924e59699
---
mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 9 ++-
.../Tosa/Transforms/TosaValidation.cpp | 59 +++++++++++++------
mlir/test/Dialect/Tosa/level_check.mlir | 26 ++++----
.../Dialect/Tosa/tosa-validation-valid.mlir | 10 ++++
4 files changed, 69 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index e088eb31338dc..b80232f112b64 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -28,19 +28,22 @@ struct TosaLevel {
int32_t MAX_LOG2_SIZE = 0;
int32_t MAX_NESTING = 0;
int32_t MAX_TENSOR_LIST_SIZE = 0;
+ int32_t MAX_SHAPE_LEN = 0;
bool operator==(const TosaLevel &rhs) {
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
MAX_NESTING == rhs.MAX_NESTING &&
- MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE &&
+ MAX_SHAPE_LEN == rhs.MAX_SHAPE_LEN;
}
};
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256,
+ 31, 6, 64, 16};
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
- 63, 256, 256};
+ 63, 256, 256, 64};
TargetEnvAttr lookupTargetEnv(Operation *op);
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 99eb83ad6d12c..25ca4b6b6f2f8 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -228,12 +228,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
if (type.getRank() > highest_rank)
return op->emitOpError() << "failed level check: " << operandOrResult
<< " rank(shape) <= MAX_RANK";
- } else if (tosa::shapeType shapeType =
- dyn_cast<tosa::shapeType>(typeToCheck)) {
- if (shapeType.getRank() > highest_rank)
- return op->emitOpError()
- << "failed shape type level check: " << typeToCheck
- << " exceeds MAX_RANK";
}
return success();
}
@@ -255,6 +249,18 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return levelCheckSize(op, v.getType(), operandOrResult);
}
+ // Perform the Level shape length check on a value.
+ LogicalResult levelCheckShapeLength(Operation *op, const Type typeToCheck,
+ const StringRef operandOrResult) {
+ if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
+ if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
+ return op->emitOpError()
+ << "failed shape type level check: " << typeToCheck
+ << " exceeds MAX_SHAPE_LEN";
+ }
+ return success();
+ }
+
// Level check sizes of all operands and results of the operation.
template <typename T>
LogicalResult levelCheckSizes(T tosaOp) {
@@ -288,6 +294,21 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return success();
}
+ // Level check shape lengths of all operands and results of an operation that
+ // are tosa.shape type.
+ template <typename T>
+ LogicalResult levelCheckShapeLens(T tosaOp) {
+ // auto op = tosaOp.getOperation();
+ for (const auto &v : tosaOp.getOperands()) {
+ if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
+ return failure();
+ }
+ if (failed(levelCheckShapeLength(tosaOp, tosaOp.getResult().getType(),
+ "result")))
+ return failure();
+ return success();
+ }
+
// Level check ranks and sizes.
LogicalResult levelCheckRanksAndSizes(Operation *op);
@@ -591,9 +612,9 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
return failure(); \
}
-#define CHECK_RANKS(tosaOp) \
+#define CHECK_SHAPE_LEN(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op)))) \
+ if (failed(levelCheckShapeLens(cast<tosa::tosaOp##Op>(op)))) \
return failure(); \
}
@@ -700,22 +721,22 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
// Shape Operators
CHECK_SIZES(ConstShape);
- // For the following operations, check whether the rank of each operand
- // is valid given a level.
+ // For the following operations, check whether the shape length of each
+ // operand is valid given a level.
// Shape Operators
- CHECK_RANKS(AddShape);
- CHECK_RANKS(ConcatShape);
- CHECK_RANKS(DivCeilShape);
- CHECK_RANKS(DivFloorShape);
- CHECK_RANKS(ModShape);
- CHECK_RANKS(MulShape);
- CHECK_RANKS(SliceShape);
- CHECK_RANKS(SubShape);
+ CHECK_SHAPE_LEN(AddShape);
+ CHECK_SHAPE_LEN(ConcatShape);
+ CHECK_SHAPE_LEN(DivCeilShape);
+ CHECK_SHAPE_LEN(DivFloorShape);
+ CHECK_SHAPE_LEN(ModShape);
+ CHECK_SHAPE_LEN(MulShape);
+ CHECK_SHAPE_LEN(SliceShape);
+ CHECK_SHAPE_LEN(SubShape);
#undef CHECK_RANKS_AND_SIZES
#undef CHECK_SIZES
-#undef CHECK_RANKS
+#undef CHECK_SHAPE_LEN
return success();
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 5e681c1ef75c8..da400b78395bd 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -390,7 +390,7 @@ func.func @test_pad_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1
func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
%1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
- // expected-error at +1 {{'tosa.reshape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+ // expected-error at +1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
%0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
return %0 : tensor<1x1x1x1x1x1x819xf32>
}
@@ -1665,22 +1665,22 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32
// -----
-func.func @test_add_shape_invalid_rank() -> !tosa.shape<13> {
- %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
- %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
- // expected-error at +1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<13>' exceeds MAX_RANK}}
- %c = tosa.add_shape %a, %b : (!tosa.shape<13>, !tosa.shape<13>) -> !tosa.shape<13>
- return %c : !tosa.shape<13>
+func.func @test_add_shape_invalid_rank() -> !tosa.shape<17> {
+ %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ // expected-error at +1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+ %c = tosa.add_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
+ return %c : !tosa.shape<17>
}
// -----
-func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
- %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
- %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
- // expected-error at +1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
- %c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
- return %c : !tosa.shape<7>
+func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<17> {
+ %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+ // expected-error at +1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+ %c = tosa.div_floor_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
+ return %c : !tosa.shape<17>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
index 036a3d4d0ba2e..d2d010d3a0845 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -37,3 +37,13 @@ func.func @test_validate_without_tosa(%arg0: f32) -> f32 {
%0 = math.asin %arg0 : f32
return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: test_pad_large_input_rank
+func.func @test_pad_large_input_rank(%arg0: tensor<13x21x3x1x1x1xf32>) -> tensor<13x21x3x1x1x1xf32> {
+ %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+ %padding = tosa.const_shape {values = dense<0> : tensor<12xindex>} : () -> !tosa.shape<12>
+ %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3x1x1x1xf32>, !tosa.shape<12>, tensor<1xf32>) -> tensor<13x21x3x1x1x1xf32>
+ return %1 : tensor<13x21x3x1x1x1xf32>
+}
>From 84e03e0191a85d9b5a3077ecd75b10326ef4411b Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 9 Jan 2026 11:24:10 +0000
Subject: [PATCH 2/2] Remove commented line and rename function
Change-Id: I20b2ece8368cff09df20481ae335c84820a37d59
---
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 25ca4b6b6f2f8..3ef5ea76e7da1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -297,8 +297,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
// Level check shape lengths of all operands and results of an operation that
// are tosa.shape type.
template <typename T>
- LogicalResult levelCheckShapeLens(T tosaOp) {
- // auto op = tosaOp.getOperation();
+ LogicalResult levelCheckShapeLengths(T tosaOp) {
for (const auto &v : tosaOp.getOperands()) {
if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
return failure();
@@ -614,7 +613,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
#define CHECK_SHAPE_LEN(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- if (failed(levelCheckShapeLens(cast<tosa::tosaOp##Op>(op)))) \
+ if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op)))) \
return failure(); \
}
More information about the Mlir-commits
mailing list