[Mlir-commits] [mlir] [mlir][tosa] Add several level checks (PR #128074)

TatWai Chong llvmlistbot at llvm.org
Thu Feb 20 14:26:54 PST 2025


https://github.com/tatwaichong updated https://github.com/llvm/llvm-project/pull/128074

>From 10ba9af2763190151d9464372f0e4be37c9a7236 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Mon, 11 Mar 2024 21:56:00 +0000
Subject: [PATCH] [mlir][tosa] Add several level checks

Add the following types of level checks to consolidate level validation:
- Add rank level checks for Erf, Cos, and Sin.
- Add MAX_LOG2_SIZE level check: The maximum value is 63 when the
  level is set to "none" and 31 when the level is set to "8K".
- Add MAX_TENSOR_LIST_SIZE level check: The maximum value is 256
  when the level is set to "none" and 64 when the level is set to
  "8K".

Co-authored-by: TatWai Chong <tatwai.chong at arm.com>
Change-Id: I797fafe504219e43950824c04839c7187065fe8e
---
 .../Tosa/Transforms/TosaValidation.cpp        | 220 +++++--
 mlir/test/Dialect/Tosa/level_check.mlir       | 592 +++++++++++++++++-
 2 files changed, 741 insertions(+), 71 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..502534bf7579e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -70,17 +70,22 @@ struct TosaLevel {
   int32_t MAX_KERNEL = 0;
   int32_t MAX_STRIDE = 0;
   int32_t MAX_SCALE = 0;
-
-  // @todo: MAX_LOG2_SIZE value and checks
+  int32_t MAX_LOG2_SIZE = 0;
+  int32_t MAX_NESTING = 0;
+  int32_t MAX_TENSOR_LIST_SIZE = 0;
 
   bool operator==(const TosaLevel &rhs) {
     return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
-           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
+           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
+           MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
+           MAX_NESTING == rhs.MAX_NESTING &&
+           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
   }
 };
 
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
-static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
+static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
+                                              63, 256,        256};
 
 //===----------------------------------------------------------------------===//
 // TOSA Validation Pass.
@@ -147,107 +152,150 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
-  bool levelCheckRank(Operation *op, const Value &v,
-                      const std::string &checkDesc) {
+  bool levelCheckListSize(Operation *op, int32_t v,
+                          const std::string &checkDesc) {
+    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
+      op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
+                        << checkDesc;
+      return false;
+    }
+    return true;
+  }
+
+  bool levelCheckRankAndSizes(Operation *op, const Value &v,
+                              const std::string &operandOrResult) {
     if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
       if (!type.hasRank()) {
         op->emitOpError() << "failed level check: unranked tensor";
         return false;
       }
       if (type.getRank() > tosaLevel.MAX_RANK) {
-        op->emitOpError() << "failed level check: " << checkDesc;
+        op->emitOpError() << "failed level check: " << operandOrResult
+                          << " rank(shape) <= MAX_RANK";
         return false;
       }
+
+      const int64_t max_dim = (1L << tosaLevel.MAX_LOG2_SIZE) - 1;
+      const int64_t max_size = (1L << (tosaLevel.MAX_LOG2_SIZE + 1)) - 1;
+
+      auto shape = type.getShape();
+      bool has_dynamic = false;
+      for (auto dim : shape) {
+        if (mlir::ShapedType::isDynamic(dim)) {
+          has_dynamic = true;
+          continue;
+        }
+        if (dim > max_dim) {
+          op->emitOpError() << "failed level check: " << operandOrResult
+                            << " shape dimension <= (1<<MAX_LOG2_SIZE) - 1";
+          return false;
+        }
+      }
+      if (!has_dynamic) {
+        int64_t element_bits = type.getElementTypeBitWidth();
+        int64_t element_bytes = std::max(
+            static_cast<int64_t>(1), static_cast<int64_t>(element_bits / 8));
+        int64_t size = element_bytes * type.getNumElements();
+        if (size > max_size) {
+          op->emitOpError()
+              << "failed level check: " << operandOrResult
+              << " tensor size (in bytes) <= (1<<MAX_LOG2_SIZE+1) - 1";
+          return false;
+        }
+      }
     }
     return true;
   }
 
   template <typename T>
-  bool levelCheckRanksFor(Operation *op) {
+  bool levelCheckRanksAndSizesFor(Operation *op) {
     if (dyn_cast<T>(op)) {
       // level check ranks of all operands and results
       for (auto v : op->getOperands()) {
-        if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
+        if (!levelCheckRankAndSizes(op, v, "operand"))
           return false;
       }
       for (auto v : op->getResults()) {
-        if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
+        if (!levelCheckRankAndSizes(op, v, "result"))
           return false;
       }
     }
     return true;
   }
 
-  bool levelCheckRanks(Operation *op) {
-#define CHECK_RANKS_FOR(tosaOp)                                                \
-  if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
+  bool levelCheckRanksAndSizes(Operation *op) {
+#define CHECK_RANKS_AND_SIZES_FOR(tosaOp)                                      \
+  if (!levelCheckRanksAndSizesFor<tosaOp##Op>(op))                             \
     return false;
 
     // tensor operators:
-    CHECK_RANKS_FOR(ArgMax);
+    CHECK_RANKS_AND_SIZES_FOR(ArgMax);
     // all activation functions:
-    CHECK_RANKS_FOR(Clamp);
-    CHECK_RANKS_FOR(Sigmoid);
-    CHECK_RANKS_FOR(Tanh);
+    CHECK_RANKS_AND_SIZES_FOR(Clamp);
+    CHECK_RANKS_AND_SIZES_FOR(Erf);
+    CHECK_RANKS_AND_SIZES_FOR(Sigmoid);
+    CHECK_RANKS_AND_SIZES_FOR(Tanh);
     // all elementwise binary operators:
-    CHECK_RANKS_FOR(Add);
-    CHECK_RANKS_FOR(ArithmeticRightShift);
-    CHECK_RANKS_FOR(BitwiseAnd);
-    CHECK_RANKS_FOR(BitwiseOr);
-    CHECK_RANKS_FOR(BitwiseXor);
-    CHECK_RANKS_FOR(IntDiv);
-    CHECK_RANKS_FOR(LogicalAnd);
-    CHECK_RANKS_FOR(LogicalLeftShift);
-    CHECK_RANKS_FOR(LogicalRightShift);
-    CHECK_RANKS_FOR(LogicalOr);
-    CHECK_RANKS_FOR(LogicalXor);
-    CHECK_RANKS_FOR(Maximum);
-    CHECK_RANKS_FOR(Minimum);
-    CHECK_RANKS_FOR(Mul);
-    CHECK_RANKS_FOR(Pow);
-    CHECK_RANKS_FOR(Sub);
-    CHECK_RANKS_FOR(Table);
+    CHECK_RANKS_AND_SIZES_FOR(Add);
+    CHECK_RANKS_AND_SIZES_FOR(ArithmeticRightShift);
+    CHECK_RANKS_AND_SIZES_FOR(BitwiseAnd);
+    CHECK_RANKS_AND_SIZES_FOR(BitwiseOr);
+    CHECK_RANKS_AND_SIZES_FOR(BitwiseXor);
+    CHECK_RANKS_AND_SIZES_FOR(IntDiv);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalAnd);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalLeftShift);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalRightShift);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalOr);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalXor);
+    CHECK_RANKS_AND_SIZES_FOR(Maximum);
+    CHECK_RANKS_AND_SIZES_FOR(Minimum);
+    CHECK_RANKS_AND_SIZES_FOR(Mul);
+    CHECK_RANKS_AND_SIZES_FOR(Pow);
+    CHECK_RANKS_AND_SIZES_FOR(Sub);
+    CHECK_RANKS_AND_SIZES_FOR(Table);
     // all elementwise unary operators:
-    CHECK_RANKS_FOR(Abs);
-    CHECK_RANKS_FOR(BitwiseNot);
-    CHECK_RANKS_FOR(Ceil);
-    CHECK_RANKS_FOR(Clz);
-    CHECK_RANKS_FOR(Exp);
-    CHECK_RANKS_FOR(Floor);
-    CHECK_RANKS_FOR(Log);
-    CHECK_RANKS_FOR(LogicalNot);
-    CHECK_RANKS_FOR(Negate);
-    CHECK_RANKS_FOR(Reciprocal);
-    CHECK_RANKS_FOR(Rsqrt);
+    CHECK_RANKS_AND_SIZES_FOR(Abs);
+    CHECK_RANKS_AND_SIZES_FOR(BitwiseNot);
+    CHECK_RANKS_AND_SIZES_FOR(Ceil);
+    CHECK_RANKS_AND_SIZES_FOR(Clz);
+    CHECK_RANKS_AND_SIZES_FOR(Cos);
+    CHECK_RANKS_AND_SIZES_FOR(Exp);
+    CHECK_RANKS_AND_SIZES_FOR(Floor);
+    CHECK_RANKS_AND_SIZES_FOR(Log);
+    CHECK_RANKS_AND_SIZES_FOR(LogicalNot);
+    CHECK_RANKS_AND_SIZES_FOR(Negate);
+    CHECK_RANKS_AND_SIZES_FOR(Reciprocal);
+    CHECK_RANKS_AND_SIZES_FOR(Rsqrt);
+    CHECK_RANKS_AND_SIZES_FOR(Sin);
     // all elementwise ternary operators:
-    CHECK_RANKS_FOR(Select);
+    CHECK_RANKS_AND_SIZES_FOR(Select);
     // all comparison operators:
-    CHECK_RANKS_FOR(Equal);
-    CHECK_RANKS_FOR(Greater);
-    CHECK_RANKS_FOR(GreaterEqual);
+    CHECK_RANKS_AND_SIZES_FOR(Equal);
+    CHECK_RANKS_AND_SIZES_FOR(Greater);
+    CHECK_RANKS_AND_SIZES_FOR(GreaterEqual);
     // all reduction operators:
-    CHECK_RANKS_FOR(ReduceAll);
-    CHECK_RANKS_FOR(ReduceAny);
-    CHECK_RANKS_FOR(ReduceMax);
-    CHECK_RANKS_FOR(ReduceMin);
-    CHECK_RANKS_FOR(ReduceProd);
-    CHECK_RANKS_FOR(ReduceSum);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceAll);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceAny);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceMax);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceMin);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceProd);
+    CHECK_RANKS_AND_SIZES_FOR(ReduceSum);
     // all data layout operators:
-    CHECK_RANKS_FOR(Concat);
-    CHECK_RANKS_FOR(Pad);
-    CHECK_RANKS_FOR(Reshape);
-    CHECK_RANKS_FOR(Reverse);
-    CHECK_RANKS_FOR(Slice);
-    CHECK_RANKS_FOR(Tile);
-    CHECK_RANKS_FOR(Transpose);
+    CHECK_RANKS_AND_SIZES_FOR(Concat);
+    CHECK_RANKS_AND_SIZES_FOR(Pad);
+    CHECK_RANKS_AND_SIZES_FOR(Reshape);
+    CHECK_RANKS_AND_SIZES_FOR(Reverse);
+    CHECK_RANKS_AND_SIZES_FOR(Slice);
+    CHECK_RANKS_AND_SIZES_FOR(Tile);
+    CHECK_RANKS_AND_SIZES_FOR(Transpose);
     // all type conversion operators:
-    CHECK_RANKS_FOR(Cast);
-    CHECK_RANKS_FOR(Rescale);
+    CHECK_RANKS_AND_SIZES_FOR(Cast);
+    CHECK_RANKS_AND_SIZES_FOR(Rescale);
     // all data nodes operators:
-    CHECK_RANKS_FOR(Const);
-    CHECK_RANKS_FOR(Identity);
+    CHECK_RANKS_AND_SIZES_FOR(Const);
+    CHECK_RANKS_AND_SIZES_FOR(Identity);
 
-#undef CHECK_RANKS_FOR
+#undef CHECK_RANKS_AND_SIZES_FOR
     return true;
   }
 
@@ -396,6 +444,32 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
+  bool levelCheckListSize(Operation *op) {
+    if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
+      return levelCheckListSize(op, concat.getInput1().size(), "input1");
+    }
+    if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
+      if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
+          !levelCheckListSize(op, custom.getOutputList().size(),
+                              "output_list")) {
+        return false;
+      }
+    }
+    if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
+      if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
+          !levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
+        return false;
+      }
+    }
+    if (auto w = dyn_cast<tosa::WhileOp>(op)) {
+      if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
+          !levelCheckListSize(op, w.getOutput().size(), "outputs")) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   // configure profile and level values from pass options profileName and
   // levelName
   void configLevelAndProfile() {
@@ -449,7 +523,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
     return success();
   }
 
-  if (!levelCheckRanks(op)) {
+  if (!levelCheckRanksAndSizes(op)) {
     return failure();
   }
 
@@ -465,6 +539,11 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
     return failure();
   }
 
+  // level check MAX_TENSOR_LIST_SIZE
+  if (!levelCheckListSize(op)) {
+    return failure();
+  }
+
   return success();
 }
 
@@ -695,6 +774,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
 }
 
 bool TosaValidation::isValidElementType(Type type) {
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(type))
+    type = quantType.getStorageType();
+
   if (isa<FloatType>(type)) {
     return type.isF32() || type.isF16() || type.isBF16();
   } else if (auto intTy = dyn_cast<IntegerType>(type)) {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 90c4551564d1e..74e71dd8d1c19 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -2,7 +2,7 @@
 // Enable all supported profiles and extensions to focus the verification of expected level errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp,mt extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable"
 
 func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
   // expected-error at +1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
@@ -12,6 +12,311 @@ func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x
 
 // -----
 
+func.func @test_clamp(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.clamp' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.clamp %arg0 {min_val = -3.40282347E+38 : f32, max_val = 3.40282347E+38 : f32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_erf(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.erf' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.erf %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_sigmoid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.sigmoid' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.sigmoid %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_tanh(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.tanh' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.tanh %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_add(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.add' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_arithmetic_right_shift(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_bitwise_and(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
+  // expected-error at +1 {{'tosa.bitwise_and' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.bitwise_and %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
+}
+
+// -----
+
+func.func @test_bitwise_or(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
+  // expected-error at +1 {{'tosa.bitwise_or' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.bitwise_or %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
+}
+
+// -----
+
+func.func @test_bitwise_xor(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
+  // expected-error at +1 {{'tosa.bitwise_xor' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.bitwise_xor %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
+}
+
+// -----
+
+func.func @test_int_div(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
+  // expected-error at +1 {{'tosa.int_div' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.int_div %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
+}
+
+// -----
+
+func.func @test_logical_and(%arg0: tensor<1x1x1x1x13x21x3xi1>, %arg1: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> {
+  // expected-error at +1 {{'tosa.logical_and' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.logical_and %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi1>, tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1>
+  return %0 : tensor<1x1x1x1x13x21x3xi1>
+}
+
+// -----
+
+func.func @test_logical_left_shift(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
+  // expected-error at +1 {{'tosa.logical_left_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.logical_left_shift %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
+}
+
+// -----
+
+func.func @test_logical_right_shift(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
+  // expected-error at +1 {{'tosa.logical_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.logical_right_shift %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
+}
+
+// -----
+
+func.func @test_logical_or(%arg0: tensor<1x1x1x1x13x1x3xi1>, %arg1: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> {
+  // expected-error at +1 {{'tosa.logical_or' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.logical_or %arg0, %arg1 : (tensor<1x1x1x1x13x1x3xi1>, tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1>
+  return %0 : tensor<1x1x1x1x13x21x3xi1>
+}
+
+// -----
+
+func.func @test_logical_xor(%arg0: tensor<1x1x1x1x13x1x3xi1>, %arg1: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> {
+  // expected-error at +1 {{'tosa.logical_xor' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.logical_xor %arg0, %arg1 : (tensor<1x1x1x1x13x1x3xi1>, tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1>
+  return %0 : tensor<1x1x1x1x13x21x3xi1>
+}
+
+// -----
+
+func.func @test_max(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x1xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.maximum' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.maximum %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x1xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_min(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x1x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.minimum' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.minimum %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x1x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_mul(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x1x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error at +1 {{'tosa.mul' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x1x3xf32>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_pow(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x1xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.pow' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.pow %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x1xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_sub(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.sub' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.sub %arg0, %arg1 : (tensor<1x1x1x1x1x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @main(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513x!quant.uniform<i16:f32, 1.0:0>>) -> tensor<1x1x1x1x1x1x64x!quant.uniform<i16:f32, 1.0:0>> {
+  // expected-error at +1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
+    %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513x!quant.uniform<i16:f32, 1.000000e+00>>) -> tensor<1x1x1x1x1x1x64x!quant.uniform<i16:f32, 1.000000e+00>>
+    return %0 : tensor<1x1x1x1x1x1x64x!quant.uniform<i16:f32, 1.0:0>>
+}
+
+// -----
+
+func.func @test_abs(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.abs' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.abs %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_bitwise_not(%arg0: tensor<1x1x1x1x13x21x1xi32>) -> tensor<1x1x1x1x13x21x1xi32> {
+  // expected-error at +1 {{'tosa.bitwise_not' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.bitwise_not %arg0 : (tensor<1x1x1x1x13x21x1xi32>) -> tensor<1x1x1x1x13x21x1xi32>
+  return %0 : tensor<1x1x1x1x13x21x1xi32>
+}
+
+// -----
+
+func.func @test_ceil(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.ceil' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.ceil %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_clz(%arg0: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
+  // expected-error at +1 {{'tosa.clz' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.clz %arg0 : (tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
+}
+
+// -----
+
+func.func @test_cos(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.cos' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.cos %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_exp(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.exp' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.exp %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_floor(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.floor' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.floor %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_log(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.log' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.log %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_logical_not(%arg0: tensor<1x1x1x1x1x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1> {
+  // expected-error at +1 {{'tosa.logical_not' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.logical_not %arg0 : (tensor<1x1x1x1x1x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1>
+  return %0 : tensor<1x1x1x1x1x21x3xi1>
+}
+
+// -----
+
+func.func @test_negate(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.negate' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.negate %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_reciprocal(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.reciprocal' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.reciprocal %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_rsqrt(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.rsqrt' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.rsqrt %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_sin(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.sin' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.sin %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_select(%arg0: tensor<1x1x1x1x1x1x1xi1>, %arg1: tensor<1x1x1x1x13x21x3xf32>, %arg2: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  // expected-error at +1 {{'tosa.select' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x1x1x1x1x1x1xi1>, tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
+func.func @test_equal(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x1x3xf32>) -> tensor<1x1x1x1x13x21x3xi1> {
+  // expected-error at +1 {{'tosa.equal' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.equal %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x1x3xf32>) -> tensor<1x1x1x1x13x21x3xi1>
+  return %0 : tensor<1x1x1x1x13x21x3xi1>
+}
+
+// -----
+
+func.func @test_greater(%arg0: tensor<1x1x1x1x13x21x1xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xi1> {
+  // expected-error at +1 {{'tosa.greater' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.greater %arg0, %arg1 : (tensor<1x1x1x1x13x21x1xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xi1>
+  return %0 : tensor<1x1x1x1x13x21x3xi1>
+}
+
+// -----
+
+func.func @test_greater_equal(%arg0: tensor<1x1x1x1x13x1x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xi1> {
+  // expected-error at +1 {{'tosa.greater_equal' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.greater_equal %arg0, %arg1 : (tensor<1x1x1x1x13x1x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xi1>
+  return %0 : tensor<1x1x1x1x13x21x3xi1>
+}
+
+// -----
+
 func.func @test_reduce_all(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1> {
   // expected-error at +1 {{'tosa.reduce_all' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = "tosa.reduce_all"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1>
@@ -68,6 +373,15 @@ func.func @test_concat(%arg0: tensor<1x1x1x13x21x3x8xf32>, %arg1: tensor<1x1x1x1
 
 // -----
 
+func.func @test_pad(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+  %padding = tosa.const_shape {value = dense<0> : tensor<14xindex>} : () -> !tosa.shape<14>
+  // expected-error at +1 {{'tosa.pad' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.pad %arg0, %padding : (tensor<1x1x1x1x13x21x3xf32>, !tosa.shape<14>) -> tensor<1x1x1x1x13x21x3xf32>
+  return %0 : tensor<1x1x1x1x13x21x3xf32>
+}
+
+// -----
+
 func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
   %1 = tosa.const_shape {value = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
   // expected-error at +1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
@@ -113,8 +427,24 @@ func.func @test_transpose(%arg0: tensor<13x21x3x1x1x1x1xf32>) -> tensor<3x13x21x
 
 // -----
 
+func.func @test_cast(%arg0: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>> {
+  // expected-error at +1 {{'tosa.cast' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.cast %arg0 : (tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+  return %0 : tensor<1x1x1x1x13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+}
+
+// -----
+
+func.func @test_rescale(%arg0: tensor<1x1x1x1x13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<1x1x1x1x13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+  // expected-error at +1 {{'tosa.rescale' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, multiplier = array<i32: 1073741824>, shift = array<i8: 30>, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<1x1x1x1x13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  return %0 : tensor<1x1x1x1x13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+}
+
+// -----
+
 func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
-  // expected-error at +1 {{'tosa.const' op failed level check: result rank(shape) <= MAX_RA}}
+  // expected-error at +1 {{'tosa.const' op failed level check: result rank(shape) <= MAX_RANK}}
   %0 = "tosa.const"() {value = dense<0> : tensor<1x1x1x1x1x1x1xi32>} : () -> tensor<1x1x1x1x1x1x1xi32>
   return %0: tensor<1x1x1x1x1x1x1xi32>
 }
@@ -153,6 +483,14 @@ func.func @test_const_ui8(%arg0 : tensor<1xui8>) {
 
 // -----
 
+func.func @test_identity(%arg0: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
+  // expected-error at +1 {{'tosa.identity' op failed level check: operand rank(shape) <= MAX_RANK}}
+  %0 = tosa.identity %arg0 : (tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+  return %0 : tensor<1x1x1x1x13x21x3xi32>
+}
+
+// -----
+
 func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
   // expected-error at +1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
   %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
@@ -750,3 +1088,253 @@ func.func @test_unranked_tensor(%arg0: tensor<*xf32>) {
   %2= tosa.slice %arg0, %0, %1 : (tensor<*xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<*xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: test_tensor_dim
+func.func @test_tensor_dim(%arg0: tensor<1x2147483648xf32>) {
+  %0 = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
+
+  // expected-error@+1 {{'tosa.slice' op failed level check: operand shape dimension <= (1<<MAX_LOG2_SIZE) - 1}}
+  %2= tosa.slice %arg0, %0, %1 : (tensor<1x2147483648xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tensor_size
+func.func @test_tensor_size(%arg0: tensor<1x1073741824xf32>) {
+  %0 = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
+
+  // expected-error@+1 {{'tosa.slice' op failed level check: operand tensor size (in bytes) <= (1<<MAX_LOG2_SIZE+1) - 1}}
+  %2= tosa.slice %arg0, %0, %1 : (tensor<1x1073741824xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tensor_size_ok
+func.func @test_tensor_size_ok(%arg0: tensor<1x1073741823xf32>) {
+  %0 = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
+
+  %2= tosa.slice %arg0, %0, %1 : (tensor<1x1073741823xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_concat_tensor_list_size
+func.func @test_concat_tensor_list_size() {
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.concat' op failed level check for MAX_TENSOR_LIST_SIZE: input1}}
+  %1= tosa.concat %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0 { axis = 0 : i32 }:
+                  (
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>
+                  ) -> tensor<65xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_custom_tensor_list_size
+func.func @test_custom_tensor_list_size() {
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.custom' op failed level check for MAX_TENSOR_LIST_SIZE: input_list}}
+  %1= tosa.custom %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0 { domain_name = "tosa_mlir_test", operator_name = "custom_test", implementation_attrs = "" }:
+                  (
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>
+                  ) -> tensor<65xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_custom_tensor_list_size_results
+func.func @test_custom_tensor_list_size_results() {
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+
+  // expected-error@+1 {{'tosa.custom' op failed level check for MAX_TENSOR_LIST_SIZE: output_list}}
+  %r:65 = tosa.custom %0 { domain_name = "tosa_mlir_test", operator_name = "custom_test", implementation_attrs = "" }:
+                  ( tensor<1xi32> )
+                  -> (
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>
+                  )
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_if_tensor_list_size
+func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
+  %1 = "tosa.cond_if"(%arg0,   // condition
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0) ({
+  ^bb0(%arg3: tensor<1xi32>):
+    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<1xi32>):
+    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  }) : (
+                    tensor<i1>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>
+                  ) -> tensor<1xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_if_tensor_list_size_outputs
+func.func @test_if_tensor_list_size_outputs(%arg0 : tensor<i1>) {
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+
+  // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
+  %r:65 = "tosa.cond_if"(%arg0) ({
+  ^bb0(%arg3: tensor<1xi32>):
+    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<1xi32>):
+    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  }) : (tensor<i1>) -> (
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>
+                  )
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_while_tensor_list_size
+func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) {
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
+  %1:2 = "tosa.while_loop"(%0, %arg0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0
+  ) ({
+  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+    %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1>
+    "tosa.yield"(%3) : (tensor<1xi1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
+  }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+  ) -> (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_while_tensor_list_size_outputs
+func.func @test_while_tensor_list_size_outputs(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) {
+  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
+  %1:65 = "tosa.while_loop"(%0, %arg0) ({
+  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+    %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1>
+    "tosa.yield"(%3) : (tensor<1xi1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
+  }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>) -> ( tensor<i32>, tensor<1x1x1x1x1x1x1xf32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+  )
+
+  return
+}



More information about the Mlir-commits mailing list