[Mlir-commits] [mlir] [mlir][tosa] Shape operation level checks limited to MAX_SHAPE_LEN (PR #175020)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 8 12:09:49 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

Following a recent specification change (https://github.com/arm/tosa-specification/pull/29), the level checks for TOSA shape operations are now bounded by MAX_SHAPE_LEN instead of MAX_RANK. The rationale is given in the specification commit message.

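This change also removes prior code that incorrectly level-checked every `shapeType` value, including shape operands of non-shape operations such as `tosa.reshape` and `tosa.pad`, rather than only the `shapeType` values of shape operations. This further aligns the implementation with the specification.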

---
Full diff: https://github.com/llvm/llvm-project/pull/175020.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h (+6-3) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+39-18) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+13-13) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-valid.mlir (+10) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index e088eb31338dc..b80232f112b64 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -28,19 +28,22 @@ struct TosaLevel {
   int32_t MAX_LOG2_SIZE = 0;
   int32_t MAX_NESTING = 0;
   int32_t MAX_TENSOR_LIST_SIZE = 0;
+  int32_t MAX_SHAPE_LEN = 0;
 
   bool operator==(const TosaLevel &rhs) {
     return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
            MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
            MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
            MAX_NESTING == rhs.MAX_NESTING &&
-           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
+           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE &&
+           MAX_SHAPE_LEN == rhs.MAX_SHAPE_LEN;
   }
 };
 
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6,  8192, 8192, 256,
+                                                31, 6,    64,   16};
 static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
-                                              63, 256,        256};
+                                              63, 256,        256,        64};
 
 TargetEnvAttr lookupTargetEnv(Operation *op);
 TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4b9c8b2030a49..02ada0894593f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -228,12 +228,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
       if (type.getRank() > highest_rank)
         return op->emitOpError() << "failed level check: " << operandOrResult
                                  << " rank(shape) <= MAX_RANK";
-    } else if (tosa::shapeType shapeType =
-                   dyn_cast<tosa::shapeType>(typeToCheck)) {
-      if (shapeType.getRank() > highest_rank)
-        return op->emitOpError()
-               << "failed shape type level check: " << typeToCheck
-               << " exceeds MAX_RANK";
     }
     return success();
   }
@@ -255,6 +249,18 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return levelCheckSize(op, v.getType(), operandOrResult);
   }
 
+  // Check a tosa.shape's length <= MAX_SHAPE_LEN; other types always pass.
+  LogicalResult levelCheckShapeLength(Operation *op, const Type typeToCheck,
+                                      const StringRef operandOrResult) {
+    if (tosa::shapeType shapeType = dyn_cast<tosa::shapeType>(typeToCheck)) {
+      if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN)
+        return op->emitOpError()
+               << "failed shape type level check: " << typeToCheck
+               << " exceeds MAX_SHAPE_LEN";
+    }
+    return success();
+  }
+
   // Level check sizes of all operands and results of the operation.
   template <typename T>
   LogicalResult levelCheckSizes(T tosaOp) {
@@ -288,6 +294,21 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return success();
   }
 
+  // Level check shape lengths of all operands and results of an operation that
+  // are tosa.shape type.
+  template <typename T>
+  LogicalResult levelCheckShapeLens(T tosaOp) {
+    // Non-shape operands/results are skipped inside levelCheckShapeLength.
+    for (const auto &v : tosaOp.getOperands()) {
+      if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand")))
+        return failure();
+    }
+    if (failed(levelCheckShapeLength(tosaOp, tosaOp.getResult().getType(),
+                                     "result")))
+      return failure();
+    return success();
+  }
+
   // Level check ranks and sizes.
   LogicalResult levelCheckRanksAndSizes(Operation *op);
 
@@ -591,9 +612,9 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
       return failure();                                                        \
   }
 
-#define CHECK_RANKS(tosaOp)                                                    \
+#define CHECK_SHAPE_LEN(tosaOp)                                                \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
-    if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op))))                   \
+    if (failed(levelCheckShapeLens(cast<tosa::tosaOp##Op>(op))))               \
       return failure();                                                        \
   }
 
@@ -700,21 +721,21 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   // Shape Operators
   CHECK_SIZES(ConstShape);
 
-  // For the following operations, check whether the rank of each operand
-  // is valid given a level.
+  // For the following operations, check whether the shape length of each
+  // operand is valid given a level.
 
   // Shape Operators
-  CHECK_RANKS(AddShape);
-  CHECK_RANKS(ConcatShape);
-  CHECK_RANKS(DivCeilShape);
-  CHECK_RANKS(DivFloorShape);
-  CHECK_RANKS(MulShape);
-  CHECK_RANKS(SliceShape);
-  CHECK_RANKS(SubShape);
+  CHECK_SHAPE_LEN(AddShape);
+  CHECK_SHAPE_LEN(ConcatShape);
+  CHECK_SHAPE_LEN(DivCeilShape);
+  CHECK_SHAPE_LEN(DivFloorShape);
+  CHECK_SHAPE_LEN(MulShape);
+  CHECK_SHAPE_LEN(SliceShape);
+  CHECK_SHAPE_LEN(SubShape);
 
 #undef CHECK_RANKS_AND_SIZES
 #undef CHECK_SIZES
-#undef CHECK_RANKS
+#undef CHECK_SHAPE_LEN
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index ec540bda5e57d..ec5a750f3339d 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -390,7 +390,7 @@ func.func @test_pad_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1
 
 func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
   %1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  // expected-error at +1 {{'tosa.reshape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+  // expected-error at +1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
   %0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
   return %0 : tensor<1x1x1x1x1x1x819xf32>
 }
@@ -1665,22 +1665,22 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32
 
 // -----
 
-func.func @test_add_shape_invalid_rank() -> !tosa.shape<13> {
-  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
-  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
-  // expected-error at +1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<13>' exceeds MAX_RANK}}
-  %c = tosa.add_shape %a, %b : (!tosa.shape<13>, !tosa.shape<13>) -> !tosa.shape<13>
-  return %c : !tosa.shape<13>
+func.func @test_add_shape_invalid_rank() -> !tosa.shape<17> {
+  %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  // expected-error at +1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+  %c = tosa.add_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
+  return %c : !tosa.shape<17>
 }
 
 // -----
 
-func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
-  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  // expected-error at +1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
-  %c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
-  return %c : !tosa.shape<7>
+func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<17> {
+  %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17>
+  // expected-error at +1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}}
+  %c = tosa.div_floor_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17>
+  return %c : !tosa.shape<17>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
index 036a3d4d0ba2e..d2d010d3a0845 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -37,3 +37,13 @@ func.func @test_validate_without_tosa(%arg0: f32) -> f32 {
   %0 = math.asin %arg0 : f32
   return %0 : f32
 }
+
+// -----
+
+// CHECK-LABEL: test_pad_large_input_rank
+func.func @test_pad_large_input_rank(%arg0: tensor<13x21x3x1x1x1xf32>) -> tensor<13x21x3x1x1x1xf32> {
+  %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<0> : tensor<12xindex>} : () -> !tosa.shape<12>
+  %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3x1x1x1xf32>, !tosa.shape<12>, tensor<1xf32>) -> tensor<13x21x3x1x1x1xf32>
+  return %1 : tensor<13x21x3x1x1x1xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/175020


More information about the Mlir-commits mailing list