[Mlir-commits] [mlir] 47c255b - Revert "[mlir][tosa] Add several level checks (#128074)" (#129549)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 3 08:39:22 PST 2025


Author: Luke Hutton
Date: 2025-03-03T08:39:18-08:00
New Revision: 47c255b3e7291fd8a7a6fb9d2a183eaad75d5adb

URL: https://github.com/llvm/llvm-project/commit/47c255b3e7291fd8a7a6fb9d2a183eaad75d5adb
DIFF: https://github.com/llvm/llvm-project/commit/47c255b3e7291fd8a7a6fb9d2a183eaad75d5adb.diff

LOG: Revert "[mlir][tosa] Add several level checks (#128074)" (#129549)

This reverts commit ccf1bfc1d50a70260d200a9137ab7924dac029a8.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
    mlir/test/Dialect/Tosa/level_check.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4757241783b9b..436890443ca9a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -61,22 +61,17 @@ struct TosaLevel {
   int32_t MAX_KERNEL = 0;
   int32_t MAX_STRIDE = 0;
   int32_t MAX_SCALE = 0;
-  int32_t MAX_LOG2_SIZE = 0;
-  int32_t MAX_NESTING = 0;
-  int32_t MAX_TENSOR_LIST_SIZE = 0;
+
+  // @todo: MAX_LOG2_SIZE value and checks
 
   bool operator==(const TosaLevel &rhs) {
     return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
-           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
-           MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
-           MAX_NESTING == rhs.MAX_NESTING &&
-           MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
+           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
   }
 };
 
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
-static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
-                                              63, 256,        256};
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
+static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
 
 //===----------------------------------------------------------------------===//
 // TOSA Validation Pass.
@@ -116,7 +111,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     constCheckers.emplace_back(checkConstantOperandPad);
   }
 
-  bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
+  bool levelCheckKernel(Operation *op, int32_t v,
+                        const std::string &checkDesc) {
     if (v > tosaLevel.MAX_KERNEL) {
       op->emitOpError() << "failed level check: " << checkDesc;
       return false;
@@ -124,7 +120,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
-  bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
+  bool levelCheckStride(Operation *op, int32_t v,
+                        const std::string &checkDesc) {
     if (v > tosaLevel.MAX_STRIDE) {
       op->emitOpError() << "failed level check: " << checkDesc;
       return false;
@@ -132,7 +129,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
-  bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
+  bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
     if (v > tosaLevel.MAX_SCALE) {
       op->emitOpError() << "failed level check: " << checkDesc;
       return false;
@@ -140,253 +137,107 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
-  bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
-      op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
-                        << checkDesc;
-      return false;
-    }
-    return true;
-  }
-
-  template <typename T>
-  bool levelCheckRank(Operation *op, const T &v,
-                      const StringRef operandOrResult, int32_t highest_rank) {
+  bool levelCheckRank(Operation *op, const Value &v,
+                      const std::string &checkDesc) {
     if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
       if (!type.hasRank()) {
         op->emitOpError() << "failed level check: unranked tensor";
         return false;
       }
-      if (type.getRank() > highest_rank) {
-        op->emitOpError() << "failed level check: " << operandOrResult
-                          << " rank(shape) <= MAX_RANK";
+      if (type.getRank() > tosaLevel.MAX_RANK) {
+        op->emitOpError() << "failed level check: " << checkDesc;
         return false;
       }
     }
     return true;
   }
 
-  // Perform the Level tensor size check
-  bool levelCheckSize(Operation *op, const Value &v,
-                      const StringRef operandOrResult) {
-    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
-      if (!type.hasRank()) {
-        op->emitOpError() << "failed level check: unranked tensor";
-        return false;
-      }
-      auto shape = type.getShape();
-      for (auto dim : shape) {
-        if (mlir::ShapedType::isDynamic(dim)) {
-          op->emitOpError() << "failed level check: " << operandOrResult
-                            << " shape dimension cannot be dynamic";
+  template <typename T>
+  bool levelCheckRanksFor(Operation *op) {
+    if (dyn_cast<T>(op)) {
+      // level check ranks of all operands and results
+      for (auto v : op->getOperands()) {
+        if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
           return false;
-        }
       }
-
-      int64_t element_bits = type.getElementTypeBitWidth();
-      int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
-      int64_t size = element_bytes * type.getNumElements();
-
-      // According to 1.11. Tensor Definitions of Tosa spec, the value of
-      // tensor_size_t is 1 << MAX_LOG2_SIZE) - 1 where MAX_LOG2_SIZE is
-      // defined in 1.7. Levels.
-      // For each tensor, the number of tensor elements multiplied by the
-      // element size in bytes must be representable as a tensor_size_t.
-      const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
-      if (size > max_size) {
-        op->emitOpError()
-            << "failed level check: " << operandOrResult
-            << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
-        return false;
-      }
-    }
-    return true;
-  }
-
-  template <typename T>
-  bool levelCheckSizes(T tosaOp) {
-    // level check sizes of all operands and results
-    auto op = tosaOp.getOperation();
-    for (auto v : op->getOperands()) {
-      if (!levelCheckSize(op, v, "operand"))
-        return false;
-    }
-
-    for (auto v : op->getResults()) {
-      if (!levelCheckSize(op, v, "result"))
-        return false;
-    }
-    return true;
-  }
-
-  template <typename T>
-  bool levelCheckRanks(T tosaOp) {
-    // level check ranks of all operands, attribute and results
-    auto op = tosaOp.getOperation();
-    for (auto v : op->getOperands()) {
-      if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
-        return false;
-    }
-
-    if (!op->getAttrs().empty()) {
-      for (NamedAttribute attr : op->getAttrs()) {
-        if (auto elemAttr = dyn_cast<ElementsAttr>(attr.getValue())) {
-          if (!levelCheckRank(op, elemAttr, "attribute", tosaLevel.MAX_RANK))
-            return false;
-        }
+      for (auto v : op->getResults()) {
+        if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
+          return false;
       }
     }
-
-    for (auto v : op->getResults()) {
-      if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
-        return false;
-    }
-    return true;
-  }
-
-  template <>
-  bool levelCheckRanks(tosa::ArgMaxOp tosaOp) {
-    auto op = tosaOp.getOperation();
-    if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
-      return false;
-
-    // rank(output) = rank(input) - 1
-    if (!levelCheckRank(op, tosaOp.getOutput(), "result",
-                        tosaLevel.MAX_RANK - 1))
-      return false;
-
     return true;
   }
 
-  template <>
-  bool levelCheckRanks(tosa::IfOp tosaOp) {
-    auto op = tosaOp.getOperation();
-
-    // Only the condition input has rank limitation.
-    if (!levelCheckRank(op, tosaOp.getCond(), "operand", tosaLevel.MAX_RANK))
-      return false;
-
-    return true;
-  }
+  bool levelCheckRanks(Operation *op) {
+#define CHECK_RANKS_FOR(tosaOp)                                                \
+  if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
+    return false;
 
-  bool levelCheckRanksAndSizes(Operation *op) {
-#define CHECK_RANKS_AND_SIZES(tosaOp)                                          \
-  if (isa<tosa::tosaOp##Op>(op)) {                                             \
-    if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op)))                          \
-      return false;                                                            \
-    if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op)))                          \
-      return false;                                                            \
-  }
-
-#define CHECK_SIZES(tosaOp)                                                    \
-  if (isa<tosa::tosaOp##Op>(op)) {                                             \
-    if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op)))                          \
-      return false;                                                            \
-  }
-
-    // For the following operators, check whether the rank and size of each
-    // tensor operand is valid in a given Level.
-
-    // Tensor Operators
-    CHECK_RANKS_AND_SIZES(ArgMax);
-    // Activation Functions
-    CHECK_RANKS_AND_SIZES(Clamp);
-    CHECK_RANKS_AND_SIZES(Erf);
-    CHECK_RANKS_AND_SIZES(Sigmoid);
-    CHECK_RANKS_AND_SIZES(Tanh);
-    // Elementwise Binary Operators
-    CHECK_RANKS_AND_SIZES(Add);
-    CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
-    CHECK_RANKS_AND_SIZES(BitwiseAnd);
-    CHECK_RANKS_AND_SIZES(BitwiseOr);
-    CHECK_RANKS_AND_SIZES(BitwiseXor);
-    CHECK_RANKS_AND_SIZES(IntDiv);
-    CHECK_RANKS_AND_SIZES(LogicalAnd);
-    CHECK_RANKS_AND_SIZES(LogicalLeftShift);
-    CHECK_RANKS_AND_SIZES(LogicalRightShift);
-    CHECK_RANKS_AND_SIZES(LogicalOr);
-    CHECK_RANKS_AND_SIZES(LogicalXor);
-    CHECK_RANKS_AND_SIZES(Maximum);
-    CHECK_RANKS_AND_SIZES(Minimum);
-    CHECK_RANKS_AND_SIZES(Mul);
-    CHECK_RANKS_AND_SIZES(Pow);
-    CHECK_RANKS_AND_SIZES(Sub);
-    CHECK_RANKS_AND_SIZES(Table);
-    // Elementwise Unary Operators
-    CHECK_RANKS_AND_SIZES(Abs);
-    CHECK_RANKS_AND_SIZES(BitwiseNot);
-    CHECK_RANKS_AND_SIZES(Ceil);
-    CHECK_RANKS_AND_SIZES(Clz);
-    CHECK_RANKS_AND_SIZES(Cos);
-    CHECK_RANKS_AND_SIZES(Exp);
-    CHECK_RANKS_AND_SIZES(Floor);
-    CHECK_RANKS_AND_SIZES(Log);
-    CHECK_RANKS_AND_SIZES(LogicalNot);
-    CHECK_RANKS_AND_SIZES(Negate);
-    CHECK_RANKS_AND_SIZES(Reciprocal);
-    CHECK_RANKS_AND_SIZES(Rsqrt);
-    CHECK_RANKS_AND_SIZES(Sin);
-    // Elementwise Ternary Operators
-    CHECK_RANKS_AND_SIZES(Select);
-    // Comparison Operators
-    CHECK_RANKS_AND_SIZES(Equal);
-    CHECK_RANKS_AND_SIZES(Greater);
-    CHECK_RANKS_AND_SIZES(GreaterEqual);
-    // Reduction Operators
-    CHECK_RANKS_AND_SIZES(ReduceAll);
-    CHECK_RANKS_AND_SIZES(ReduceAny);
-    CHECK_RANKS_AND_SIZES(ReduceMax);
-    CHECK_RANKS_AND_SIZES(ReduceMin);
-    CHECK_RANKS_AND_SIZES(ReduceProduct);
-    CHECK_RANKS_AND_SIZES(ReduceSum);
-    // Data Layout Operators
-    CHECK_RANKS_AND_SIZES(Concat);
-    CHECK_RANKS_AND_SIZES(Pad);
-    CHECK_RANKS_AND_SIZES(Reshape);
-    CHECK_RANKS_AND_SIZES(Reverse);
-    CHECK_RANKS_AND_SIZES(Slice);
-    CHECK_RANKS_AND_SIZES(Tile);
-    CHECK_RANKS_AND_SIZES(Transpose);
-    // Type Conversion
-    CHECK_RANKS_AND_SIZES(Cast);
-    CHECK_RANKS_AND_SIZES(Rescale);
-    // Control Flow Operators
-    CHECK_RANKS_AND_SIZES(If);
-    // Variable Operators
-    CHECK_RANKS_AND_SIZES(Variable);
-    CHECK_RANKS_AND_SIZES(VariableWrite);
-    CHECK_RANKS_AND_SIZES(VariableRead);
-    // Data Nodes
-    CHECK_RANKS_AND_SIZES(Const);
-    CHECK_RANKS_AND_SIZES(Identity);
-
-    // For the following operators, check whether the size of each tensor
-    // operand is valid in a given Level.
-
-    // Tensor Operators
-    CHECK_SIZES(AvgPool2d);
-    CHECK_SIZES(Conv2D);
-    CHECK_SIZES(Conv3D);
-    CHECK_SIZES(DepthwiseConv2D);
-    CHECK_SIZES(TransposeConv2D);
-    CHECK_SIZES(FFT2d);
-    CHECK_SIZES(MatMul);
-    CHECK_SIZES(MaxPool2d);
-    CHECK_SIZES(RFFT2d);
-    // Scatter/Gather Operators
-    CHECK_SIZES(Gather);
-    CHECK_SIZES(Scatter);
-    // Image Operators
-    CHECK_SIZES(Resize);
-    // Custom Operators
-    CHECK_SIZES(Custom);
-    // Control Flow Operators
-    CHECK_SIZES(While);
-    // Shape Operators
-    CHECK_SIZES(ConstShape);
-
-#undef CHECK_RANKS_AND_SIZES
-#undef CHECK_SIZES
+    // tensor operators:
+    CHECK_RANKS_FOR(ArgMax);
+    // all activation functions:
+    CHECK_RANKS_FOR(Clamp);
+    CHECK_RANKS_FOR(Sigmoid);
+    CHECK_RANKS_FOR(Tanh);
+    // all elementwise binary operators:
+    CHECK_RANKS_FOR(Add);
+    CHECK_RANKS_FOR(ArithmeticRightShift);
+    CHECK_RANKS_FOR(BitwiseAnd);
+    CHECK_RANKS_FOR(BitwiseOr);
+    CHECK_RANKS_FOR(BitwiseXor);
+    CHECK_RANKS_FOR(IntDiv);
+    CHECK_RANKS_FOR(LogicalAnd);
+    CHECK_RANKS_FOR(LogicalLeftShift);
+    CHECK_RANKS_FOR(LogicalRightShift);
+    CHECK_RANKS_FOR(LogicalOr);
+    CHECK_RANKS_FOR(LogicalXor);
+    CHECK_RANKS_FOR(Maximum);
+    CHECK_RANKS_FOR(Minimum);
+    CHECK_RANKS_FOR(Mul);
+    CHECK_RANKS_FOR(Pow);
+    CHECK_RANKS_FOR(Sub);
+    CHECK_RANKS_FOR(Table);
+    // all elementwise unary operators:
+    CHECK_RANKS_FOR(Abs);
+    CHECK_RANKS_FOR(BitwiseNot);
+    CHECK_RANKS_FOR(Ceil);
+    CHECK_RANKS_FOR(Clz);
+    CHECK_RANKS_FOR(Exp);
+    CHECK_RANKS_FOR(Floor);
+    CHECK_RANKS_FOR(Log);
+    CHECK_RANKS_FOR(LogicalNot);
+    CHECK_RANKS_FOR(Negate);
+    CHECK_RANKS_FOR(Reciprocal);
+    CHECK_RANKS_FOR(Rsqrt);
+    // all elementwise ternary operators:
+    CHECK_RANKS_FOR(Select);
+    // all comparison operators:
+    CHECK_RANKS_FOR(Equal);
+    CHECK_RANKS_FOR(Greater);
+    CHECK_RANKS_FOR(GreaterEqual);
+    // all reduction operators:
+    CHECK_RANKS_FOR(ReduceAll);
+    CHECK_RANKS_FOR(ReduceAny);
+    CHECK_RANKS_FOR(ReduceMax);
+    CHECK_RANKS_FOR(ReduceMin);
+    CHECK_RANKS_FOR(ReduceProduct);
+    CHECK_RANKS_FOR(ReduceSum);
+    // all data layout operators:
+    CHECK_RANKS_FOR(Concat);
+    CHECK_RANKS_FOR(Pad);
+    CHECK_RANKS_FOR(Reshape);
+    CHECK_RANKS_FOR(Reverse);
+    CHECK_RANKS_FOR(Slice);
+    CHECK_RANKS_FOR(Tile);
+    CHECK_RANKS_FOR(Transpose);
+    // all type conversion operators:
+    CHECK_RANKS_FOR(Cast);
+    CHECK_RANKS_FOR(Rescale);
+    // all data nodes operators:
+    CHECK_RANKS_FOR(Const);
+    CHECK_RANKS_FOR(Identity);
+
+#undef CHECK_RANKS_FOR
     return true;
   }
 
@@ -535,32 +386,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
-  bool levelCheckListSize(Operation *op) {
-    if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
-      return levelCheckListSize(op, concat.getInput1().size(), "input1");
-    }
-    if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
-      if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
-          !levelCheckListSize(op, custom.getOutputList().size(),
-                              "output_list")) {
-        return false;
-      }
-    }
-    if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
-      if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
-          !levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
-        return false;
-      }
-    }
-    if (auto w = dyn_cast<tosa::WhileOp>(op)) {
-      if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
-          !levelCheckListSize(op, w.getOutput().size(), "outputs")) {
-        return false;
-      }
-    }
-    return true;
-  }
-
   // configure profile and level values from pass options profileName and
   // levelName
   void configLevelAndProfile() {
@@ -614,6 +439,10 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
     return success();
   }
 
+  if (!levelCheckRanks(op)) {
+    return failure();
+  }
+
   // additional level checks from spec 0.70
   if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
       !levelCheckConv<tosa::Conv2DOp>(op) ||
@@ -626,15 +455,6 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
     return failure();
   }
 
-  if (!levelCheckRanksAndSizes(op)) {
-    return failure();
-  }
-
-  // level check MAX_TENSOR_LIST_SIZE
-  if (!levelCheckListSize(op)) {
-    return failure();
-  }
-
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index ddcb78b87c22b..d2958efe1bb24 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,6 +1,10 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
+//--------------------------------------------------------------------------------------------------
+// Enable all supported profiles and extensions to focus the verification of expected level errors.
+//--------------------------------------------------------------------------------------------------
 
-func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow"
+
+func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
   // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32>
   return %0 : tensor<1x1x1x1x29x4xi32>
@@ -8,312 +12,7 @@ func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tenso
 
 // -----
 
-func.func @test_clamp_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.clamp' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.clamp %arg0 {min_val = -3.40282347E+38 : f32, max_val = 3.40282347E+38 : f32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_erf_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.erf' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.erf %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_sigmoid_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.sigmoid' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.sigmoid %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_tanh_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.tanh' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.tanh %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.add' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_bitwise_and_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
-  // expected-error@+1 {{'tosa.bitwise_and' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.bitwise_and %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
-  return %0 : tensor<1x1x1x1x13x21x3xi32>
-}
-
-// -----
-
-func.func @test_bitwise_or_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
-  // expected-error@+1 {{'tosa.bitwise_or' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.bitwise_or %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
-  return %0 : tensor<1x1x1x1x13x21x3xi32>
-}
-
-// -----
-
-func.func @test_bitwise_xor_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
-  // expected-error@+1 {{'tosa.bitwise_xor' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.bitwise_xor %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
-  return %0 : tensor<1x1x1x1x13x21x3xi32>
-}
-
-// -----
-
-func.func @test_int_div_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
-  // expected-error@+1 {{'tosa.int_div' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.int_div %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
-  return %0 : tensor<1x1x1x1x13x21x3xi32>
-}
-
-// -----
-
-func.func @test_logical_and_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi1>, %arg1: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> {
-  // expected-error@+1 {{'tosa.logical_and' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.logical_and %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi1>, tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1>
-  return %0 : tensor<1x1x1x1x13x21x3xi1>
-}
-
-// -----
-
-func.func @test_logical_left_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
-  // expected-error@+1 {{'tosa.logical_left_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.logical_left_shift %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
-  return %0 : tensor<1x1x1x1x13x21x3xi32>
-}
-
-// -----
-
-func.func @test_logical_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
-  // expected-error@+1 {{'tosa.logical_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.logical_right_shift %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
-  return %0 : tensor<1x1x1x1x13x21x3xi32>
-}
-
-// -----
-
-func.func @test_logical_or_rank_invalid(%arg0: tensor<1x1x1x1x13x1x3xi1>, %arg1: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> {
-  // expected-error@+1 {{'tosa.logical_or' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.logical_or %arg0, %arg1 : (tensor<1x1x1x1x13x1x3xi1>, tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1>
-  return %0 : tensor<1x1x1x1x13x21x3xi1>
-}
-
-// -----
-
-func.func @test_logical_xor_rank_invalid(%arg0: tensor<1x1x1x1x13x1x3xi1>, %arg1: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> {
-  // expected-error@+1 {{'tosa.logical_xor' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.logical_xor %arg0, %arg1 : (tensor<1x1x1x1x13x1x3xi1>, tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1>
-  return %0 : tensor<1x1x1x1x13x21x3xi1>
-}
-
-// -----
-
-func.func @test_max_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x1xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.maximum' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.maximum %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x1xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_min_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x1x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.minimum' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.minimum %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x1x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_mul_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x1x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.mul' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x1x3xf32>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_pow_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x1xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.pow' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.pow %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x1xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.sub' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.sub %arg0, %arg1 : (tensor<1x1x1x1x1x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> {
-  // expected-error@+1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}}
-    %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16>
-    return %0 : tensor<1x1x1x1x1x1x64xi16>
-}
-
-// -----
-
-func.func @test_abs_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.abs' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.abs %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_bitwise_not_rank_invalid(%arg0: tensor<1x1x1x1x13x21x1xi32>) -> tensor<1x1x1x1x13x21x1xi32> {
-  // expected-error@+1 {{'tosa.bitwise_not' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.bitwise_not %arg0 : (tensor<1x1x1x1x13x21x1xi32>) -> tensor<1x1x1x1x13x21x1xi32>
-  return %0 : tensor<1x1x1x1x13x21x1xi32>
-}
-
-// -----
-
-func.func @test_ceil_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.ceil' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.ceil %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_clz_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
-  // expected-error@+1 {{'tosa.clz' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.clz %arg0 : (tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
-  return %0 : tensor<1x1x1x1x13x21x3xi32>
-}
-
-// -----
-
-func.func @test_cos_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.cos' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.cos %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_exp_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.exp' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.exp %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_floor_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.floor' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.floor %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_log_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.log' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.log %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_logical_not_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1> {
-  // expected-error@+1 {{'tosa.logical_not' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.logical_not %arg0 : (tensor<1x1x1x1x1x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1>
-  return %0 : tensor<1x1x1x1x1x21x3xi1>
-}
-
-// -----
-
-func.func @test_negate_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.negate' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.negate %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_reciprocal_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.reciprocal' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.reciprocal %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_rsqrt_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.rsqrt' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.rsqrt %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_sin_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.sin' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.sin %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_select_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1xi1>, %arg1: tensor<1x1x1x1x13x21x3xf32>, %arg2: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  // expected-error@+1 {{'tosa.select' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x1x1x1x1x1x1xi1>, tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_equal_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x1x3xf32>) -> tensor<1x1x1x1x13x21x3xi1> {
-  // expected-error@+1 {{'tosa.equal' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.equal %arg0, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x1x3xf32>) -> tensor<1x1x1x1x13x21x3xi1>
-  return %0 : tensor<1x1x1x1x13x21x3xi1>
-}
-
-// -----
-
-func.func @test_greater_rank_invalid(%arg0: tensor<1x1x1x1x13x21x1xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xi1> {
-  // expected-error@+1 {{'tosa.greater' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.greater %arg0, %arg1 : (tensor<1x1x1x1x13x21x1xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xi1>
-  return %0 : tensor<1x1x1x1x13x21x3xi1>
-}
-
-// -----
-
-func.func @test_greater_equal_rank_invalid(%arg0: tensor<1x1x1x1x13x1x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xi1> {
-  // expected-error@+1 {{'tosa.greater_equal' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.greater_equal %arg0, %arg1 : (tensor<1x1x1x1x13x1x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xi1>
-  return %0 : tensor<1x1x1x1x13x21x3xi1>
-}
-
-// -----
-
-func.func @test_reduce_all_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1> {
+func.func @test_reduce_all(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1> {
   // expected-error@+1 {{'tosa.reduce_all' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = "tosa.reduce_all"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1>
   return %0 : tensor<1x1x1x1x1x21x3xi1>
@@ -321,7 +20,7 @@ func.func @test_reduce_all_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> te
 
 // -----
 
-func.func @test_reduce_any_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> {
+func.func @test_reduce_any(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> {
   // expected-error at +1 {{'tosa.reduce_any' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = "tosa.reduce_any"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1>
   return %0 : tensor<1x1x1x1x13x21x3xi1>
@@ -329,7 +28,7 @@ func.func @test_reduce_any_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> te
 
 // -----
 
-func.func @test_reduce_max_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_reduce_max(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error at +1 {{'tosa.reduce_max' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = "tosa.reduce_max"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
@@ -337,7 +36,7 @@ func.func @test_reduce_max_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> t
 
 // -----
 
-func.func @test_reduce_min_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_reduce_min(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error at +1 {{'tosa.reduce_min' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = "tosa.reduce_min"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
@@ -345,7 +44,7 @@ func.func @test_reduce_min_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> t
 
 // -----
 
-func.func @test_reduce_prod_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_reduce_prod(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error at +1 {{'tosa.reduce_product' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = "tosa.reduce_product"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
@@ -353,7 +52,7 @@ func.func @test_reduce_prod_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) ->
 
 // -----
 
-func.func @test_reduce_sum_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_reduce_sum(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error at +1 {{'tosa.reduce_sum' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
@@ -361,7 +60,7 @@ func.func @test_reduce_sum_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> t
 
 // -----
 
-func.func @test_concat_rank_invalid(%arg0: tensor<1x1x1x13x21x3x8xf32>, %arg1: tensor<1x1x1x13x21x3x8xf32>) -> tensor<1x1x1x26x21x3x8xf32> {
+func.func @test_concat(%arg0: tensor<1x1x1x13x21x3x8xf32>, %arg1: tensor<1x1x1x13x21x3x8xf32>) -> tensor<1x1x1x26x21x3x8xf32> {
   // expected-error at +1 {{'tosa.concat' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = "tosa.concat"(%arg0, %arg1) {axis = 3 : i32} : (tensor<1x1x1x13x21x3x8xf32>, tensor<1x1x1x13x21x3x8xf32>) -> tensor<1x1x1x26x21x3x8xf32>
   return %0 : tensor<1x1x1x26x21x3x8xf32>
@@ -369,16 +68,7 @@ func.func @test_concat_rank_invalid(%arg0: tensor<1x1x1x13x21x3x8xf32>, %arg1: t
 
 // -----
 
-func.func @test_pad_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
-  %padding = tosa.const_shape {value = dense<0> : tensor<14xindex>} : () -> !tosa.shape<14>
-  // expected-error at +1 {{'tosa.pad' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.pad %arg0, %padding : (tensor<1x1x1x1x13x21x3xf32>, !tosa.shape<14>) -> tensor<1x1x1x1x13x21x3xf32>
-  return %0 : tensor<1x1x1x1x13x21x3xf32>
-}
-
-// -----
-
-func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
+func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
   %1 = tosa.const_shape {value = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
   // expected-error at +1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
   %0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
@@ -387,7 +77,7 @@ func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1
 
 // -----
 
-func.func @test_reverse_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_reverse(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
   // expected-error at +1 {{'tosa.reverse' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = "tosa.reverse"(%arg0) {axis = 0 : i32} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
   return %0 : tensor<1x1x1x1x13x21x3xf32>
@@ -395,7 +85,7 @@ func.func @test_reverse_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tens
 
 // -----
 // CHECK-LABEL: slice
-func.func @test_slice_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11x1xf32> {
+func.func @test_slice(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11x1xf32> {
   %0 = tosa.const_shape {value = dense<[0, 0, 0, 0, 6, 8, 0]> : tensor<7xindex>} : () -> !tosa.shape<7>
   %1 = tosa.const_shape {value = dense<[1, 1, 1, 1, 4, 11, 1]> : tensor<7xindex>} : () -> !tosa.shape<7>
   // expected-error at +1 {{'tosa.slice' op failed level check: operand rank(shape) <= MAX_RANK}}
@@ -405,7 +95,7 @@ func.func @test_slice_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor
 
 // -----
 // CHECK-LABEL: tile
-func.func @test_tile_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> {
+func.func @test_tile(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> {
   %cst = tosa.const_shape { value = dense<[1, 1, 1, 1, 3, 1, 2]> : tensor<7xindex> } : () -> !tosa.shape<7>
   // expected-error at +1 {{'tosa.tile' op failed level check: operand rank(shape) <= MAX_RANK}}
   %0 = tosa.tile %arg0, %cst : (tensor<1x1x1x1x13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x39x21x6xf32>
@@ -414,7 +104,7 @@ func.func @test_tile_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<
 
 // -----
 
-func.func @test_transpose_rank_invalid(%arg0: tensor<13x21x3x1x1x1x1xf32>) -> tensor<3x13x21x1x1x1x1xf32> {
+func.func @test_transpose(%arg0: tensor<13x21x3x1x1x1x1xf32>) -> tensor<3x13x21x1x1x1x1xf32> {
   // expected-error at +1 {{'tosa.transpose' op failed level check: operand rank(shape) <= MAX_RANK}}
   %1 = "tosa.transpose"(%arg0) {perms = array<i32: 2, 0, 1, 3, 4, 5, 6>} : (tensor<13x21x3x1x1x1x1xf32>) -> tensor<3x13x21x1x1x1x1xf32>
   return %1 : tensor<3x13x21x1x1x1x1xf32>
@@ -422,43 +112,14 @@ func.func @test_transpose_rank_invalid(%arg0: tensor<13x21x3x1x1x1x1xf32>) -> te
 
 // -----
 
-func.func @test_cast_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi16> {
-  // expected-error at +1 {{'tosa.cast' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.cast %arg0 : (tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi16>
-  return %0 : tensor<1x1x1x1x13x21x3xi16>
-}
-
-// -----
-
-func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tensor<1x1x1x1x13x21x3xi8> {
-  // expected-error at +1 {{'tosa.rescale' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, multiplier = array<i32: 1073741824>, shift = array<i8: 30>, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>) -> tensor<1x1x1x1x13x21x3xi8>
-  return %0 : tensor<1x1x1x1x13x21x3xi8>
-}
-
-// -----
 func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
-  // expected-error at +1 {{'tosa.const' op failed level check: attribute rank(shape) <= MAX_RANK}}
+  // expected-error at +1 {{'tosa.const' op failed level check: result rank(shape) <= MAX_RA}}
   %0 = "tosa.const"() {value = dense<0> : tensor<1x1x1x1x1x1x1xi32>} : () -> tensor<1x1x1x1x1x1x1xi32>
   return %0: tensor<1x1x1x1x1x1x1xi32>
 }
 
 // -----
 
-func.func @test_add_rank_valid(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
-  %0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  return %0 : tensor<f32>
-}
-
-// -----
-
-func.func @test_const_rank_valid(%arg0 : tensor<i32>) -> tensor<i32> {
-  %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  return %0: tensor<i32>
-}
-
-// -----
-
 func.func @test_const_i2(%arg0 : tensor<1xi2>) {
   // expected-error at +1 {{'tosa.const' op is not profile-aligned: element type 'i2' is not legal}}
   %0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2>
@@ -491,21 +152,6 @@ func.func @test_const_ui8(%arg0 : tensor<1xui8>) {
 
 // -----
 
-func.func @test_identity_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
-  // expected-error at +1 {{'tosa.identity' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.identity %arg0 : (tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
-  return %0 : tensor<1x1x1x1x13x21x3xi32>
-}
-
-// -----
-
-func.func @test_identity_rank_valid(%arg0: tensor<i32>) -> tensor<i32> {
-  %0 = tosa.identity %arg0 : (tensor<i32>) -> tensor<i32>
-  return %0 : tensor<i32>
-}
-
-// -----
-
 func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
   // expected-error at +1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
   %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
@@ -1043,150 +689,8 @@ func.func @test_resize_scale_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x7970
 
 // -----
 
-func.func @test_tensor_size_valid(%arg0: tensor<1x536870911xf32>) {
-  %0 = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
-  %1 = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
-  %2= tosa.slice %arg0, %0, %1 : (tensor<1x536870911xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xf32>
-  return
-}
-
-// -----
-
-func.func @test_slice_tensor_size_invalid(%arg0: tensor<1x536870912xf32>) {
-  %0 = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
-  %1 = tosa.const_shape {value = dense<536870912> : tensor<2xindex>} : () -> !tosa.shape<2>
-  // expected-error at +1 {{'tosa.slice' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %2= tosa.slice %arg0, %0, %1 : (tensor<1x536870912xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xf32>
-  return
-}
-
-
-// -----
-
-func.func @test_resize_tensor_size_invalid(%arg0: tensor<1x23178x23178x1xf32>) {
-  %scale = tosa.const_shape { value = dense<[127, 49, 12, 49]> : tensor<4xindex> } : () -> !tosa.shape<4>
-  %offset = tosa.const_shape { value = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
-  %border = tosa.const_shape { value = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
-  // expected-error at +1 {{'tosa.resize' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.resize %arg0, %scale, %offset, %border {mode = "NEAREST_NEIGHBOR"} : (tensor<1x23178x23178x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
-  return
-}
-
-// -----
-
-func.func @test_avg_pool2d_tensor_size_invalid(%arg0: tensor<1x23178x23178x9xf32>) -> tensor<1x23178x23178x9xf32> {
-  // expected-error at +1 {{'tosa.avg_pool2d' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x23178x23178x9xf32>) -> tensor<1x23178x23178x9xf32>
-  return %0 : tensor<1x23178x23178x9xf32>
-}
-
-// -----
-
-func.func @test_conv2d_tensor_size_invalid(%arg0: tensor<1x23178x23178x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x23178x23178x8xf32> {
-  // expected-error at +1 {{'tosa.conv2d' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x23178x23178x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x23178x23178x8xf32>
-  return %0 : tensor<1x23178x23178x8xf32>
-}
-
-// -----
-
-func.func @test_fft2d_tensor_size_invalid(%arg0: tensor<8191x8191x8191xf32>, %arg1: tensor<8191x8191x8191xf32>) -> (tensor<8191x8191x8191xf32>, tensor<8191x8191x8191xf32>) {
-  // expected-error at +1 {{'tosa.fft2d' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<8191x8191x8191xf32>, tensor<8191x8191x8191xf32>) -> (tensor<8191x8191x8191xf32>, tensor<8191x8191x8191xf32>)
-  return %0, %1 : tensor<8191x8191x8191xf32>, tensor<8191x8191x8191xf32>
-}
-
-// -----
-
-func.func @test_rfft2d_tensor_size_invalid(%arg0: tensor<536870912x8x16xf32>) -> (tensor<536870912x8x9xf32>, tensor<536870912x8x9xf32>) {
-  // expected-error at +1 {{'tosa.rfft2d' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0, %1 = tosa.rfft2d %arg0 : (tensor<536870912x8x16xf32>) -> (tensor<536870912x8x9xf32>, tensor<536870912x8x9xf32>)
-  return %0, %1 : tensor<536870912x8x9xf32>, tensor<536870912x8x9xf32>
-}
-
-// -----
-
-func.func @test_matmul_tensor_size_invalid(%arg0: tensor<23178x20000x19xf32>, %arg1: tensor<23178x19x28xf32>) -> tensor<23178x20000x28xf32> {
-  // expected-error at +1 {{'tosa.matmul' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<23178x20000x19xf32>, tensor<23178x19x28xf32>) -> tensor<23178x20000x28xf32>
-  return %0 : tensor<23178x20000x28xf32>
-}
-
-// -----
-
-func.func @test_gather_tensor_size_invalid(%arg0: tensor<536870912x21x3xf32>, %arg1: tensor<536870912x26xi32>) -> tensor<536870912x26x3xf32> {
-  // expected-error at +1 {{'tosa.gather' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.gather %arg0, %arg1 : (tensor<536870912x21x3xf32>, tensor<536870912x26xi32>) -> tensor<536870912x26x3xf32>
-  return %0 : tensor<536870912x26x3xf32>
-}
-
-// -----
-
-func.func @test_custom_tensor_size_invalid(%arg0: tensor<536870912xi32>) -> tensor<536870912xi32> {
-  // expected-error at +1 {{'tosa.custom' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<536870912xi32>) -> (tensor<536870912xi32>)
-  return %0 : tensor<536870912xi32>
-}
-
-// -----
-
-func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %arg1: tensor<268435456x26xi32>) -> tensor<268435456x26x3xf32> {
-  // expected-error at +1 {{'tosa.gather' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.gather %arg0, %arg1 : (tensor<268435456x21x3xf32>, tensor<268435456x26xi32>) -> tensor<268435456x26x3xf32>
-  return %0 : tensor<268435456x26x3xf32>
-}
-
-// -----
-
-func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
-  // expected-error at +1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
-  return %0 : tensor<13x210000000x3xf32>
-}
-
-// -----
-
-func.func @test_variable_read_write_tensor_size_invalid() -> () {
-  tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32>
-  // expected-error at +1 {{'tosa.variable.read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.variable.read @stored_var : tensor<536870912xf32>
-  // expected-error at +1 {{'tosa.variable.write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  tosa.variable.write @stored_var, %0 : tensor<536870912xf32>
-  return
-}
-
-// -----
-
-func.func @test_while_loop_tensor_size_invalid(%arg0: tensor<536870912xi32>, %arg1: tensor<i32>) {
-  %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // expected-error at +1 {{'tosa.while_loop' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<536870912xi32>) -> (tensor<i32>, tensor<i32>, tensor<536870912xi32>) {
-    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
-    tosa.yield %2 : tensor<i1>
-  } do {
-  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<536870912xi32>):
-    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-    %3 = "tosa.const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
-    %4 = tosa.add %arg3, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
-    // expected-error at +1 {{'tosa.add' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-    %5 = tosa.add %arg4, %3 : (tensor<536870912xi32>, tensor<1xi32>) -> tensor<536870912xi32>
-    %6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
-    tosa.yield %6, %4, %5 : tensor<i32>, tensor<i32>, tensor<536870912xi32>
-  }
-  return
-}
-
-// -----
-
-func.func @test_const_shape() -> !tosa.shape<4> {
-  %cst = tosa.const_shape { value = dense<[1, 1, 536870912, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
-  return %cst : !tosa.shape<4>
-}
-
-// -----
-
-func.func @test_cond_if_rank_valid(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1x1x1x1x1x1x1xf32>, %arg2: tensor<i1>) -> tensor<1x1x1x1x1x1x1xf32> {
+// CHECK-LABEL: @test_cond_if
+func.func @test_cond_if(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1x1x1x1x1x1x1xf32>, %arg2: tensor<i1>) -> tensor<1x1x1x1x1x1x1xf32> {
   %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
   ^bb0(%arg3: tensor<1x1x1x1x1x1x1xf32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
     "tosa.yield"(%arg3) : (tensor<1x1x1x1x1x1x1xf32>) -> ()
@@ -1199,32 +703,6 @@ func.func @test_cond_if_rank_valid(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tens
 
 // -----
 
-func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1: tensor<1x1x1x1x1x1x1x1xf32>, %arg2: tensor<1x1x1x1x1x1x1x1xi1>) -> tensor<1x1x1x1x1x1x1x1xf32> {
-  // expected-error at +1 {{'tosa.cond_if' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
-  ^bb0(%arg3: tensor<1x1x1x1x1x1x1x1xf32>, %arg4: tensor<1x1x1x1x1x1x1x1xf32>):
-    "tosa.yield"(%arg3) : (tensor<1x1x1x1x1x1x1x1xf32>) -> ()
-  },  {
-  ^bb0(%arg3: tensor<1x1x1x1x1x1x1x1xf32>, %arg4: tensor<1x1x1x1x1x1x1x1xf32>):
-    "tosa.yield"(%arg4) : (tensor<1x1x1x1x1x1x1x1xf32>) -> ()
-  }) : (tensor<1x1x1x1x1x1x1x1xi1>, tensor<1x1x1x1x1x1x1x1xf32>, tensor<1x1x1x1x1x1x1x1xf32>) -> tensor<1x1x1x1x1x1x1x1xf32>
-  return %0 : tensor<1x1x1x1x1x1x1x1xf32>
-}
-
-// -----
-
-func.func @test_variable_read_write_rank_invalid() -> () {
-  // expected-error at +1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}}
-  tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32>
-  // expected-error at +1 {{'tosa.variable.read' op failed level check: result rank(shape) <= MAX_RANK}}
-  %0 = tosa.variable.read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
-  // expected-error at +1 {{'tosa.variable.write' op failed level check: operand rank(shape) <= MAX_RANK}}
-  tosa.variable.write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
-  return
-}
-
-// -----
-
 // CHECK-LABEL: @test_while_loop
 func.func @test_while_loop(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<i32>) {
   %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
@@ -1244,8 +722,8 @@ func.func @test_while_loop(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<i32>)
 
 // -----
 
-// CHECK-LABEL: @test_custom_rank_valid
-func.func @test_custom_rank_valid(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x10xi32> {
+// CHECK-LABEL: @test_custom
+func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x10xi32> {
   %0 = "tosa.custom"(%arg0) {operator_name="custom_test", domain_name="tosa_mlir_test", implementation_attrs=""} :
            (tensor<1x1x1x1x1x1x10xi32>) -> (tensor<1x1x1x1x1x1x10xi32>)
   return %0 : tensor<1x1x1x1x1x1x10xi32>
@@ -1262,218 +740,3 @@ func.func @test_unranked_tensor(%arg0: tensor<*xf32>) {
   %2= tosa.slice %arg0, %0, %1 : (tensor<*xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<*xf32>
   return
 }
-
-// -----
-
-// CHECK-LABEL: test_concat_tensor_list_size
-func.func @test_concat_tensor_list_size() {
-  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // expected-error at +1 {{'tosa.concat' op failed level check for MAX_TENSOR_LIST_SIZE: input1}}
-  %1= tosa.concat %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0 { axis = 0 : i32 }:
-                  (
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  ) -> tensor<65xi32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_custom_tensor_list_size
-func.func @test_custom_tensor_list_size() {
-  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // expected-error at +1 {{'tosa.custom' op failed level check for MAX_TENSOR_LIST_SIZE: input_list}}
-  %1= tosa.custom %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0 { domain_name = "tosa_mlir_test", operator_name = "custom_test", implementation_attrs = "" }:
-                  (
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  ) -> tensor<65xi32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_custom_tensor_list_size_results
-func.func @test_custom_tensor_list_size_results() {
-  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-
-  // expected-error at +1 {{'tosa.custom' op failed level check for MAX_TENSOR_LIST_SIZE: output_list}}
-  %r:65 = tosa.custom %0 { domain_name = "tosa_mlir_test", operator_name = "custom_test", implementation_attrs = "" }:
-                  ( tensor<1xi32> )
-                  -> (
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  )
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_if_tensor_list_size
-func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
-  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // expected-error at +1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
-  %1 = "tosa.cond_if"(%arg0,   // condition
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0) ({
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
-  },  {
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
-  }) : (
-                    tensor<i1>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  ) -> tensor<1xi32>
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_if_tensor_list_size_outputs
-func.func @test_if_tensor_list_size_outputs(%arg0 : tensor<i1>) {
-  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-
-  // expected-error at +1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
-  %r:65 = "tosa.cond_if"(%arg0) ({
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
-  },  {
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
-  }) : (tensor<i1>) -> (
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  )
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_while_tensor_list_size
-func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) {
-  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // expected-error at +1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
-  %1:2 = "tosa.while_loop"(%0, %arg0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0
-  ) ({
-  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
-    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1>
-    "tosa.yield"(%3) : (tensor<1xi1>) -> ()
-  },  {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
-    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-    "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
-  }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
-  ) -> (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>)
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_while_tensor_list_size_outputs
-func.func @test_while_tensor_list_size_outputs(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1xi32>) {
-  %0 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
-  // expected-error at +1 {{'tosa.while_loop' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
-  %1:65 = "tosa.while_loop"(%0, %arg0) ({
-  ^bb0(%arg3: tensor<1xi32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
-    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
-    %3 = "tosa.logical_not"(%2) : (tensor<1xi1>) -> tensor<1xi1>
-    "tosa.yield"(%3) : (tensor<1xi1>) -> ()
-  },  {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<1x1x1x1x1x1x1xf32>):
-    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-    "tosa.yield"(%3, %arg4) : (tensor<i32>, tensor<1x1x1x1x1x1x1xf32>) -> ()
-  }) : (tensor<1xi32>, tensor<1x1x1x1x1x1x1xf32>) -> ( tensor<i32>, tensor<1x1x1x1x1x1x1xf32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
-  )
-
-  return
-}


        


More information about the Mlir-commits mailing list