[Mlir-commits] [mlir] [mlir][tosa] Shape operation level checks limited to MAX_SHAPE_LEN (PR #175020)

Luke Hutton llvmlistbot at llvm.org
Fri Jan 9 14:04:35 PST 2026


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/175020

>From 573b807e7fb4442110e090f96b2fdb0ad6a9c7b0 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 7 Jan 2026 13:44:00 +0000
Subject: [PATCH 1/3] [mlir][tosa] Shape operation level checks limited to
 MAX_SHAPE_LEN

Following a recent specification change
(https://github.com/arm/tosa-specification/pull/29), the level checks
for TOSA shape operations are now bounded by MAX_SHAPE_LEN rather than
MAX_RANK (MAX_SHAPE_LEN is 16 for the 8K level and 64 for level none).
The rationale is given in the specification commit message.

This change also removes prior code that incorrectly applied the rank
check to every `shapeType` operand and result of every operation, rather
than checking only the `shapeType` levels of shape operations, further
aligning the implementation with the specification.

Change-Id: I466497a14191721e2c484f360c92a00924e59699
---
 mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h |  9 ++-
 .../Tosa/Transforms/TosaValidation.cpp        | 69 ++++++++++++-------
 mlir/test/Dialect/Tosa/level_check.mlir       | 26 +++----
 .../Dialect/Tosa/tosa-validation-valid.mlir   | 10 +++
 4 files changed, 74 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index e088eb31338dc..b80232f112b64 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -28,19 +28,22 @@ struct TosaLevel {
   int32_t MAX_LOG2_SIZE = 0;
   int32_t MAX_NESTING = 0;
   int32_t MAX_TENSOR_LIST_SIZE = 0;
+  int32_t MAX_SHAPE_LEN = 0;
 
   bool operator==(const TosaLevel &rhs) {
     return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
            MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
            MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
            MAX_NESTING == rhs.MAX_NESTING &&
-           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
+           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE &&
+           MAX_SHAPE_LEN == rhs.MAX_SHAPE_LEN;
   }
 };
 
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6,  8192, 8192, 256,
+                                                31, 6,    64,   16};
 static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
-                                              63, 256,        256};
+                                              63, 256,        256,        64};
 
 TargetEnvAttr lookupTargetEnv(Operation *op);
 TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index a900aef04f753..20b11fa8715ed 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -228,12 +228,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
       if (type.getRank() > highest_rank)
         return op->emitOpError() << "failed level check: " << operandOrResult
                                  << " rank(shape) <= MAX_RANK";
-    } else if (tosa::shapeType shapeType =
-                   dyn_cast<tosa::shapeType>(typeToCheck)) {
-      if (shapeType.getRank() > highest_rank)
-        return op->emitOpError()
-               << "failed shape type level check: " << typeToCheck
-               << " exceeds MAX_RANK";
     }
     return success();
   }
@@ -255,6 +249,18 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return levelCheckSize(op, v.getType(), operandOrResult);
   }
 
+  // Level check a type: tosa.shape length <= MAX_SHAPE_LEN; other types pass.
+  LogicalResult levelCheckShapeLength(Operation *op, const Type typeToCheck,
+                                      const StringRef operandOrResult) {
+    if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
+      if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
+        return op->emitOpError()
+               << "failed shape type level check: " << typeToCheck
+               << " exceeds MAX_SHAPE_LEN";
+    }
+    return success();
+  }
+
   // Level check sizes of all operands and results of the operation.
   template <typename T>
   LogicalResult levelCheckSizes(T tosaOp) {
@@ -288,6 +294,21 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return success();
   }
 
+  // Level check shape lengths of all operands and results of an operation that
+  // are tosa.shape type.
+  template <typename T>
+  LogicalResult levelCheckShapeLens(T tosaOp) {
+    // auto op = tosaOp.getOperation();
+    for (const auto &v : tosaOp.getOperands()) {
+      if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
+        return failure();
+    }
+    if (failed(levelCheckShapeLength(tosaOp, tosaOp.getResult().getType(),
+                                     "result")))
+      return failure();
+    return success();
+  }
+
   // Level check ranks and sizes.
   LogicalResult levelCheckRanksAndSizes(Operation *op);
 
@@ -591,9 +612,9 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
       return failure();                                                        \
   }
 
-#define CHECK_RANKS(tosaOp)                                                    \
+#define CHECK_SHAPE_LEN(tosaOp)                                                \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
-    if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op))))                   \
+    if (failed(levelCheckShapeLens(cast<tosa::tosaOp##Op>(op))))               \
       return failure();                                                        \
   }
 
@@ -700,27 +721,27 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   // Shape Operators
   CHECK_SIZES(ConstShape);
 
-  // For the following operations, check whether the rank of each operand
-  // is valid given a level.
+  // For the following operations, check whether the shape length of each
+  // operand is valid given a level.
 
   // Shape Operators
-  CHECK_RANKS(AddShape);
-  CHECK_RANKS(ConcatShape);
-  CHECK_RANKS(DivCeilShape);
-  CHECK_RANKS(DivFloorShape);
-  CHECK_RANKS(Exp2Shape);
-  CHECK_RANKS(Log2CeilShape);
-  CHECK_RANKS(Log2FloorShape);
-  CHECK_RANKS(MaxShape);
-  CHECK_RANKS(MinShape);
-  CHECK_RANKS(ModShape);
-  CHECK_RANKS(MulShape);
-  CHECK_RANKS(SliceShape);
-  CHECK_RANKS(SubShape);
+  CHECK_SHAPE_LEN(AddShape);
+  CHECK_SHAPE_LEN(ConcatShape);
+  CHECK_SHAPE_LEN(DivCeilShape);
+  CHECK_SHAPE_LEN(DivFloorShape);
+  CHECK_SHAPE_LEN(Exp2Shape);
+  CHECK_SHAPE_LEN(Log2CeilShape);
+  CHECK_SHAPE_LEN(Log2FloorShape);
+  CHECK_SHAPE_LEN(MaxShape);
+  CHECK_SHAPE_LEN(MinShape);
+  CHECK_SHAPE_LEN(ModShape);
+  CHECK_SHAPE_LEN(MulShape);
+  CHECK_SHAPE_LEN(SliceShape);
+  CHECK_SHAPE_LEN(SubShape);
 
 #undef CHECK_RANKS_AND_SIZES
 #undef CHECK_SIZES
-#undef CHECK_RANKS
+#undef CHECK_SHAPE_LEN
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d874ab6d23a50..47de248b375e5 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -390,7 +390,7 @@ func.func @test_pad_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1
 
 func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
   %1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  // expected-error at +1 {{'tosa.reshape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+  // expected-error at +1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
   %0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
   return %0 : tensor<1x1x1x1x1x1x819xf32>
 }
@@ -1665,22 +1665,22 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32
 
 // -----
 
-func.func @test_add_shape_invalid_rank() -> !tosa.shape<13> {
-  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
-  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
-  // expected-error at +1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<13>' exceeds MAX_RANK}}
-  %c = tosa.add_shape %a, %b : (!tosa.shape<13>, !tosa.shape<13>) -> !tosa.shape<13>
-  return %c : !tosa.shape<13>
+func.func @test_add_shape_invalid_rank() -> !tosa.shape<17> {
+  %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  // expected-error at +1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+  %c = tosa.add_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
+  return %c : !tosa.shape<17>
 }
 
 // -----
 
-func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
-  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  // expected-error at +1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
-  %c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
-  return %c : !tosa.shape<7>
+func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<17> {
+  %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  // expected-error at +1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+  %c = tosa.div_floor_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
+  return %c : !tosa.shape<17>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
index 036a3d4d0ba2e..d2d010d3a0845 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -37,3 +37,13 @@ func.func @test_validate_without_tosa(%arg0: f32) -> f32 {
   %0 = math.asin %arg0 : f32
   return %0 : f32
 }
+
+// -----
+
+// CHECK-LABEL: test_pad_large_input_rank
+func.func @test_pad_large_input_rank(%arg0: tensor<13x21x3x1x1x1xf32>) -> tensor<13x21x3x1x1x1xf32> {
+  %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<0> : tensor<12xindex>} : () -> !tosa.shape<12>
+  %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3x1x1x1xf32>, !tosa.shape<12>, tensor<1xf32>) -> tensor<13x21x3x1x1x1xf32>
+  return %1 : tensor<13x21x3x1x1x1xf32>
+}

>From c8e3535a0b54df06053a48d2681eddf7ac9d7462 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 9 Jan 2026 11:24:10 +0000
Subject: [PATCH 2/3] Remove commented-out line and rename levelCheckShapeLens

Change-Id: I20b2ece8368cff09df20481ae335c84820a37d59
---
 mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 20b11fa8715ed..de795c6bd9088 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -297,8 +297,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   // Level check shape lengths of all operands and results of an operation that
   // are tosa.shape type.
   template <typename T>
-  LogicalResult levelCheckShapeLens(T tosaOp) {
-    // auto op = tosaOp.getOperation();
+  LogicalResult levelCheckShapeLengths(T tosaOp) {
     for (const auto &v : tosaOp.getOperands()) {
       if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
         return failure();
@@ -614,7 +613,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
 
 #define CHECK_SHAPE_LEN(tosaOp)                                                \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
-    if (failed(levelCheckShapeLens(cast<tosa::tosaOp##Op>(op))))               \
+    if (failed(levelCheckShapeLengths(cast<tosa::tosaOp##Op>(op))))            \
       return failure();                                                        \
   }
 

>From f5ec722db49e4b8ebf27a2324875053523cd7177 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 9 Jan 2026 11:36:40 +0000
Subject: [PATCH 3/3] Fix up shape-op level check tests and operand iteration

Change-Id: I06dadb7911e238540237c50c6f85adf188fd2258
---
 .../Tosa/Transforms/TosaValidation.cpp        |  2 +-
 mlir/test/Dialect/Tosa/level_check.mlir       | 42 +++++++++----------
 2 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index de795c6bd9088..387d38411f0fe 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -298,7 +298,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   // are tosa.shape type.
   template <typename T>
   LogicalResult levelCheckShapeLengths(T tosaOp) {
-    for (const auto &v : tosaOp.getOperands()) {
+    for (const auto &v : tosaOp->getOperands()) {
       if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
         return failure();
     }
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 47de248b375e5..dd5ece417cf9e 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1721,37 +1721,37 @@ func.func @test_concat_shape_invalid_list_size() {
 
 // -----
 
-func.func @test_exp2_shape_invalid_rank() -> !tosa.shape<7> {
-  %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  // expected-error at +1 {{'tosa.exp2_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
-  %1 = tosa.exp2_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
-  return %1 : !tosa.shape<7>
+func.func @test_exp2_shape_invalid_rank() -> !tosa.shape<17> {
+  %0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  // expected-error at +1 {{'tosa.exp2_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+  %1 = tosa.exp2_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17>
+  return %1 : !tosa.shape<17>
 }
 
 // -----
 
-func.func @test_log2_floor_shape_invalid_rank() -> !tosa.shape<7> {
-  %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  // expected-error at +1 {{'tosa.log2_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
-  %1 = tosa.log2_floor_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
-  return %1 : !tosa.shape<7>
+func.func @test_log2_floor_shape_invalid_rank() -> !tosa.shape<17> {
+  %0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  // expected-error at +1 {{'tosa.log2_floor_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+  %1 = tosa.log2_floor_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17>
+  return %1 : !tosa.shape<17>
 }
 
 // -----
 
-func.func @test_log2_ceil_shape_invalid_rank() -> !tosa.shape<7> {
-  %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  // expected-error at +1 {{'tosa.log2_ceil_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
-  %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7>
-  return %1 : !tosa.shape<7>
+func.func @test_log2_ceil_shape_invalid_rank() -> !tosa.shape<17> {
+  %0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  // expected-error at +1 {{'tosa.log2_ceil_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+  %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17>
+  return %1 : !tosa.shape<17>
 }
 
 // -----
 
-func.func @test_mod_shape_invalid_rank() -> !tosa.shape<9> {
-  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<9xindex>} : () -> !tosa.shape<9>
-  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<9xindex>} : () -> !tosa.shape<9>
-  // expected-error at +1 {{'tosa.mod_shape' op failed shape type level check: '!tosa.shape<9>' exceeds MAX_RANK}}
-  %c = tosa.mod_shape %a, %b : (!tosa.shape<9>, !tosa.shape<9>) -> !tosa.shape<9>
-  return %c : !tosa.shape<9>
+func.func @test_mod_shape_invalid_rank() -> !tosa.shape<17> {
+  %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  // expected-error at +1 {{'tosa.mod_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+  %c = tosa.mod_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
+  return %c : !tosa.shape<17>
 }



More information about the Mlir-commits mailing list